Compare commits

..

2 Commits

Author SHA1 Message Date
pablodanswer
56fd40e606 incorporate base default padding for modals 2024-12-09 14:17:40 -08:00
pablodanswer
415d644200 remove double x 2024-12-09 14:08:29 -08:00
135 changed files with 2919 additions and 3410 deletions

View File

@@ -1,57 +0,0 @@
"""delete_input_prompts
Revision ID: bf7a81109301
Revises: f7a894b06d02
Create Date: 2024-12-09 12:00:49.884228
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "bf7a81109301"
down_revision = "f7a894b06d02"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_table("inputprompt__user")
op.drop_table("inputprompt")
def downgrade() -> None:
op.create_table(
"inputprompt",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("prompt", sa.String(), nullable=False),
sa.Column("content", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("is_public", sa.Boolean(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"inputprompt__user",
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["input_prompt_id"],
["inputprompt.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["inputprompt.id"],
),
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
)

View File

@@ -1,4 +1,3 @@
import hashlib
import secrets
import uuid
from urllib.parse import quote
@@ -19,8 +18,7 @@ _API_KEY_HEADER_NAME = "Authorization"
# organizations like the Internet Engineering Task Force (IETF).
_API_KEY_HEADER_ALTERNATIVE_NAME = "X-Danswer-Authorization"
_BEARER_PREFIX = "Bearer "
_API_KEY_PREFIX = "on_"
_DEPRECATED_API_KEY_PREFIX = "dn_"
_API_KEY_PREFIX = "dn_"
_API_KEY_LEN = 192
@@ -54,9 +52,7 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
if not api_key.startswith(_API_KEY_PREFIX) and not api_key.startswith(
_DEPRECATED_API_KEY_PREFIX
):
if not api_key.startswith(_API_KEY_PREFIX):
return None
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
@@ -67,19 +63,10 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:
return unquote(tenant_id) if tenant_id else None
def _deprecated_hash_api_key(api_key: str) -> str:
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
def hash_api_key(api_key: str) -> str:
# NOTE: no salt is needed, as the API key is randomly generated
# and overlaps are impossible
if api_key.startswith(_API_KEY_PREFIX):
return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
elif api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
return _deprecated_hash_api_key(api_key)
else:
raise ValueError(f"Invalid API key prefix: {api_key[:3]}")
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
def build_displayable_api_key(api_key: str) -> str:

View File

@@ -9,6 +9,7 @@ from danswer.utils.special_types import JSON_ro
def get_invited_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_USER_STORE_KEY))
except KvKeyNotFoundError:
return list()

View File

@@ -131,7 +131,7 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
def user_needs_to_be_verified() -> bool:
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
if AUTH_TYPE == AuthType.BASIC:
return REQUIRE_EMAIL_VERIFICATION
# For other auth types, if the user is authenticated it's assumed that

View File

@@ -598,7 +598,7 @@ def connector_indexing_proxy_task(
db_session,
"Connector termination signal detected",
)
except Exception:
finally:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(

View File

@@ -680,28 +680,17 @@ def monitor_ccpair_indexing_taskset(
)
task_logger.warning(msg)
try:
index_attempt = get_index_attempt(
db_session, payload.index_attempt_id
)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
except Exception:
task_logger.exception(
"monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
f"attempt={payload.index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
redis_connector_index.reset()
return

View File

@@ -206,9 +206,7 @@ class Answer:
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
search_result, displayed_search_results_map = SearchTool.get_search_result(
current_llm_call
) or ([], {})
search_result = SearchTool.get_search_result(current_llm_call) or []
# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
@@ -226,7 +224,6 @@ class Answer:
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
display_doc_order_dict=displayed_search_results_map,
)
response_handler_manager = LLMResponseHandlerManager(

View File

@@ -35,18 +35,13 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):
class CitationResponseHandler(AnswerResponseHandler):
def __init__(
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.display_doc_order_dict = display_doc_order_dict
self.citation_processor = CitationProcessor(
context_docs=self.context_docs,
doc_id_to_rank_map=self.doc_id_to_rank_map,
display_doc_order_dict=self.display_doc_order_dict,
)
self.processed_text = ""
self.citations: list[CitationInfo] = []

View File

@@ -22,16 +22,12 @@ class CitationProcessor:
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
stop_stream: str | None = STOP_STREAM_PAT,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.stop_stream = stop_stream
self.order_mapping = doc_id_to_rank_map.order_mapping
self.display_doc_order_dict = (
display_doc_order_dict # original order of docs to displayed to user
)
self.llm_out = ""
self.max_citation_num = len(context_docs)
self.citation_order: list[int] = []
@@ -102,18 +98,6 @@ class CitationProcessor:
self.citation_order.index(real_citation_num) + 1
)
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_doc_order_dict:
displayed_citation_num = self.display_doc_order_dict[
context_llm_doc.document_id
]
else:
displayed_citation_num = real_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# Skip consecutive citations of the same work
if target_citation_num in self.current_citations:
start, end = citation.span()
@@ -134,7 +118,6 @@ class CitationProcessor:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
@@ -156,7 +139,6 @@ class CitationProcessor:
if target_citation_num not in self.cited_inds:
self.cited_inds.add(target_citation_num)
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
@@ -166,8 +148,7 @@ class CitationProcessor:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
# + f"[[{target_citation_num}]]({link})"
+ f"[[{target_citation_num}]]({link})"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
@@ -175,8 +156,7 @@ class CitationProcessor:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
# + f"[[{target_citation_num}]]()"
+ f"[[{target_citation_num}]]()"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length

View File

@@ -348,12 +348,6 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)
# Egnyte specific configs
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
@@ -417,28 +411,21 @@ LARGE_CHUNK_RATIO = 4
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(
os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT") or 3
)
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
# The indexer will warn in the logs whenver a document exceeds this threshold (in bytes)
INDEXING_SIZE_WARNING_THRESHOLD = int(
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD") or 100 * 1024 * 1024
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
)
# during indexing, will log verbose memory diff stats every x batches and at the end.
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
# Maximum file size in a document to be indexed
MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000)
MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
#####
# Miscellaneous

View File

@@ -3,6 +3,7 @@ import os
PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
PERSONAS_YAML = "./danswer/seeding/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking

View File

@@ -132,7 +132,6 @@ class DocumentSource(str, Enum):
NOT_APPLICABLE = "not_applicable"
FRESHDESK = "freshdesk"
FIREFLIES = "fireflies"
EGNYTE = "egnyte"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]

View File

@@ -1,384 +0,0 @@
import io
import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from logging import Logger
from typing import Any
from typing import cast
from typing import IO
import requests
from retry import retry
from danswer.configs.app_configs import EGNYTE_BASE_DOMAIN
from danswer.configs.app_configs import EGNYTE_CLIENT_ID
from danswer.configs.app_configs import EGNYTE_CLIENT_SECRET
from danswer.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import OAuthConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import read_text_file
from danswer.utils.logger import setup_logger
logger = setup_logger()
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
_TIMEOUT = 60
def _request_with_retries(
method: str,
url: str,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = _TIMEOUT,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method,
url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code != 403:
logger.exception(
f"Failed to call Egnyte API.\n"
f"URL: {url}\n"
f"Headers: {headers}\n"
f"Data: {data}\n"
f"Params: {params}"
)
raise e
return response
return _make_request()
def _parse_last_modified(last_modified: str) -> datetime:
return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
tzinfo=timezone.utc
)
def _process_egnyte_file(
file_metadata: dict[str, Any],
file_content: IO,
base_url: str,
folder_path: str | None = None,
) -> Document | None:
"""Process an Egnyte file into a Document object
Args:
file_data: The file data from Egnyte API
file_content: The raw content of the file in bytes
base_url: The base URL for the Egnyte instance
folder_path: Optional folder path to filter results
"""
# Skip if file path doesn't match folder path filter
if folder_path and not file_metadata["path"].startswith(folder_path):
raise ValueError(
f"File path {file_metadata['path']} does not match folder path {folder_path}"
)
file_name = file_metadata["name"]
extension = get_file_ext(file_name)
if not is_valid_file_ext(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None
# Extract text content based on file type
if is_text_file_extension(file_name):
encoding = detect_encoding(file_content)
file_content_raw, file_metadata = read_text_file(
file_content, encoding=encoding, ignore_danswer_metadata=False
)
else:
file_content_raw = extract_file_text(
file=file_content,
file_name=file_name,
break_on_unprocessable=True,
)
# Build the web URL for the file
web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}"
# Create document metadata
metadata: dict[str, str | list[str]] = {
"file_path": file_metadata["path"],
"last_modified": file_metadata.get("last_modified", ""),
}
# Add lock info if present
if lock_info := file_metadata.get("lock_info"):
metadata[
"lock_owner"
] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}"
# Create the document owners
primary_owner = None
if uploaded_by := file_metadata.get("uploaded_by"):
primary_owner = BasicExpertInfo(
email=uploaded_by, # Using username as email since that's what we have
)
# Create the document
return Document(
id=f"egnyte-{file_metadata['entry_id']}",
sections=[Section(text=file_content_raw.strip(), link=web_url)],
source=DocumentSource.EGNYTE,
semantic_identifier=file_name,
metadata=metadata,
doc_updated_at=(
_parse_last_modified(file_metadata["last_modified"])
if "last_modified" in file_metadata
else None
),
primary_owners=[primary_owner] if primary_owner else None,
)
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
def __init__(
self,
folder_path: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.domain = "" # will always be set in `load_credentials`
self.folder_path = folder_path or "" # Root folder if not specified
self.batch_size = batch_size
self.access_token: str | None = None
@classmethod
def oauth_id(cls) -> DocumentSource:
return DocumentSource.EGNYTE
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
if EGNYTE_LOCALHOST_OVERRIDE:
base_domain = EGNYTE_LOCALHOST_OVERRIDE
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
return (
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
f"?client_id={EGNYTE_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&scope=Egnyte.filesystem"
f"&state={state}"
f"&response_type=code"
)
@classmethod
def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_CLIENT_SECRET:
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
# Exchange code for token
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
data = {
"client_id": EGNYTE_CLIENT_ID,
"client_secret": EGNYTE_CLIENT_SECRET,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": f"{EGNYTE_LOCALHOST_OVERRIDE or ''}/connector/oauth/callback/egnyte",
"scope": "Egnyte.filesystem",
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = _request_with_retries(
method="POST",
url=url,
data=data,
headers=headers,
# try a lot faster since this is a realtime flow
backoff=0,
delay=0.1,
)
if not response.ok:
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
token_data = response.json()
return {
"domain": EGNYTE_BASE_DOMAIN,
"access_token": token_data["access_token"],
}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.domain = credentials["domain"]
self.access_token = credentials["access_token"]
return None
def _get_files_list(
self,
path: str,
) -> list[dict[str, Any]]:
if not self.access_token or not self.domain:
raise ConnectorMissingCredentialError("Egnyte")
headers = {
"Authorization": f"Bearer {self.access_token}",
}
params: dict[str, Any] = {
"list_content": True,
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
response = _request_with_retries(
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
)
if not response.ok:
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
data = response.json()
all_files: list[dict[str, Any]] = []
# Add files from current directory
all_files.extend(data.get("files", []))
# Recursively traverse folders
for item in data.get("folders", []):
all_files.extend(self._get_files_list(item["path"]))
return all_files
def _filter_files(
self,
files: list[dict[str, Any]],
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> list[dict[str, Any]]:
filtered_files = []
for file in files:
if file["is_folder"]:
continue
file_modified = _parse_last_modified(file["last_modified"])
if start_time and file_modified < start_time:
continue
if end_time and file_modified > end_time:
continue
filtered_files.append(file)
return filtered_files
def _process_files(
self,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> Generator[list[Document], None, None]:
files = self._get_files_list(self.folder_path)
files = self._filter_files(files, start_time, end_time)
current_batch: list[Document] = []
for file in files:
try:
# Set up request with streaming enabled
headers = {
"Authorization": f"Bearer {self.access_token}",
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
response = _request_with_retries(
method="GET",
url=url,
headers=headers,
timeout=_TIMEOUT,
stream=True,
)
if not response.ok:
logger.error(
f"Failed to fetch file content: {file['path']} (status code: {response.status_code})"
)
continue
# Stream the response content into a BytesIO buffer
buffer = io.BytesIO()
for chunk in response.iter_content(chunk_size=8192):
if chunk:
buffer.write(chunk)
# Reset buffer's position to the start
buffer.seek(0)
# Process the streamed file content
doc = _process_egnyte_file(
file_metadata=file,
file_content=buffer,
base_url=_EGNYTE_APP_BASE.format(domain=self.domain),
folder_path=self.folder_path,
)
if doc is not None:
current_batch.append(doc)
if len(current_batch) >= self.batch_size:
yield current_batch
current_batch = []
except Exception:
logger.exception(f"Failed to process file {file['path']}")
continue
if current_batch:
yield current_batch
def load_from_state(self) -> GenerateDocumentsOutput:
yield from self._process_files()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_time = datetime.fromtimestamp(start, tz=timezone.utc)
end_time = datetime.fromtimestamp(end, tz=timezone.utc)
yield from self._process_files(start_time=start_time, end_time=end_time)
if __name__ == "__main__":
connector = EgnyteConnector()
connector.load_credentials(
{
"domain": os.environ["EGNYTE_DOMAIN"],
"access_token": os.environ["EGNYTE_ACCESS_TOKEN"],
}
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -15,7 +15,6 @@ from danswer.connectors.danswer_jira.connector import JiraConnector
from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.egnyte.connector import EgnyteConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.fireflies.connector import FirefliesConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector
@@ -104,7 +103,6 @@ def identify_connector_class(
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -17,11 +17,11 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_session_with_tenant
from danswer.file_processing.extract_file_text import check_file_ext_is_valid
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import load_files_from_zip
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
@@ -50,7 +50,7 @@ def _read_files_and_metadata(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), file, metadata
elif is_valid_file_ext(extension):
elif check_file_ext_is_valid(extension):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
@@ -63,7 +63,7 @@ def _process_file(
pdf_pass: str | None = None,
) -> list[Document]:
extension = get_file_ext(file_name)
if not is_valid_file_ext(extension):
if not check_file_ext_is_valid(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return []

View File

@@ -4,13 +4,11 @@ from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any
from typing import cast
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import MAX_FILE_SIZE_BYTES
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_drive.doc_conversion import build_slim_document
from danswer.connectors.google_drive.doc_conversion import (
@@ -454,14 +452,12 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if isinstance(self.creds, ServiceAccountCredentials)
else self._manage_oauth_retrieval
)
drive_files = retrieval_method(
return retrieval_method(
is_slim=is_slim,
start=start,
end=end,
)
return drive_files
def _extract_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
@@ -477,15 +473,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
files_to_process = []
# Gather the files into batches to be processed in parallel
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
if (
file.get("size")
and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES
):
logger.warning(
f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes"
)
continue
files_to_process.append(file)
if len(files_to_process) >= LARGE_BATCH_SIZE:
yield from _process_files_batch(

View File

@@ -16,7 +16,7 @@ logger = setup_logger()
FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
"shortcutDetails, owners(emailAddress), size)"
"shortcutDetails, owners(emailAddress))"
)
SLIM_FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "

View File

@@ -2,7 +2,6 @@ import abc
from collections.abc import Iterator
from typing import Any
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import SlimDocument
@@ -65,23 +64,6 @@ class SlimConnector(BaseConnector):
raise NotImplementedError
class OAuthConnector(BaseConnector):
@classmethod
@abc.abstractmethod
def oauth_id(cls) -> DocumentSource:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
raise NotImplementedError
# Event driven
class EventConnector(BaseConnector):
@abc.abstractmethod

View File

@@ -132,6 +132,7 @@ class LinearConnector(LoadConnector, PollConnector):
branchName
customerTicketCount
description
descriptionData
comments {
nodes {
url
@@ -214,6 +215,5 @@ class LinearConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
connector = LinearConnector()
connector.load_credentials({"linear_api_key": os.environ["LINEAR_API_KEY"]})
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -171,9 +171,7 @@ def thread_to_doc(
else first_message
)
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
"\n", " "
)
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}"
return Document(
id=f"{channel_id}__{thread[0]['ts']}",

View File

@@ -204,8 +204,7 @@ def _build_documents_blocks(
continue
seen_docs_identifiers.add(d.document_id)
# Strip newlines from the semantic identifier for Slackbot formatting
doc_sem_id = d.semantic_identifier.replace("\n", " ")
doc_sem_id = d.semantic_identifier
if d.source_type == DocumentSource.SLACK.value:
doc_sem_id = "#" + doc_sem_id

View File

@@ -373,9 +373,7 @@ def handle_regular_answer(
respond_in_thread(
client=client,
channel=channel,
receiver_ids=[message_info.sender]
if message_info.is_bot_msg and message_info.sender
else receiver_ids,
receiver_ids=receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,

View File

@@ -11,7 +11,6 @@ from retry import retry
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import Block
from slack_sdk.models.blocks import SectionBlock
from slack_sdk.models.metadata import Metadata
from slack_sdk.socket_mode import SocketModeClient
@@ -141,40 +140,6 @@ def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
def _check_for_url_in_block(block: Block) -> bool:
"""
Check if the block has a key that contains "url" in it
"""
block_dict = block.to_dict()
def check_dict_for_url(d: dict) -> bool:
for key, value in d.items():
if "url" in key.lower():
return True
if isinstance(value, dict):
if check_dict_for_url(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_dict_for_url(item):
return True
return False
return check_dict_for_url(block_dict)
def _build_error_block(error_message: str) -> Block:
"""
Build an error block to display in slack so that the user can see
the error without completely breaking
"""
display_text = (
"There was an error displaying all of the Onyx answers."
f" Please let an admin or an onyx developer know. Error: {error_message}"
)
return SectionBlock(text=display_text)
@retry(
tries=DANSWER_BOT_NUM_RETRIES,
delay=0.25,
@@ -197,9 +162,24 @@ def respond_in_thread(
message_ids: list[str] = []
if not receiver_ids:
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
try:
response = slack_call(
channel=channel,
text=text,
blocks=blocks,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks,
thread_ts=thread_ts,
@@ -207,68 +187,8 @@ def respond_in_thread(
unfurl_links=unfurl,
unfurl_media=unfurl,
)
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
logger.warning("Trying again without blocks that have urls")
if not blocks:
raise e
blocks_without_urls = [
block for block in blocks if not _check_for_url_in_block(block)
]
blocks_without_urls.append(_build_error_block(str(e)))
# Try again wtihout blocks containing url
response = slack_call(
channel=channel,
text=text,
blocks=blocks_without_urls,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
logger.warning("Trying again without blocks that have urls")
if not blocks:
raise e
blocks_without_urls = [
block for block in blocks if not _check_for_url_in_block(block)
]
blocks_without_urls.append(_build_error_block(str(e)))
# Try again wtihout blocks containing url
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks_without_urls,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
message_ids.append(response["message_ts"])
return message_ids

View File

@@ -20,6 +20,7 @@ from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import CredentialDataUpdateRequest
from danswer.utils.logger import setup_logger
@@ -261,8 +262,7 @@ def _cleanup_credential__user_group_relationships__no_commit(
def alter_credential(
credential_id: int,
name: str,
credential_json: dict[str, Any],
credential_data: CredentialDataUpdateRequest,
user: User,
db_session: Session,
) -> Credential | None:
@@ -272,13 +272,11 @@ def alter_credential(
if credential is None:
return None
credential.name = name
credential.name = credential_data.name
# Assign a new dictionary to credential.credential_json
credential.credential_json = {
**credential.credential_json,
**credential_json,
}
# Update only the keys present in credential_data.credential_json
for key, value in credential_data.credential_json.items():
credential.credential_json[key] = value
credential.user_id = user.id if user is not None else None
db_session.commit()
@@ -311,8 +309,8 @@ def update_credential_json(
credential = fetch_credential_by_id(credential_id, user, db_session)
if credential is None:
return None
credential.credential_json = credential_json
db_session.commit()
return credential

View File

@@ -522,16 +522,12 @@ def expire_index_attempts(
search_settings_id: int,
db_session: Session,
) -> None:
not_started_query = (
update(IndexAttempt)
delete_query = (
delete(IndexAttempt)
.where(IndexAttempt.search_settings_id == search_settings_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
.values(
status=IndexingStatus.CANCELED,
error_msg="Canceled, likely due to model swap",
)
)
db_session.execute(not_started_query)
db_session.execute(delete_query)
update_query = (
update(IndexAttempt)
@@ -553,14 +549,9 @@ def cancel_indexing_attempts_for_ccpair(
include_secondary_index: bool = False,
) -> None:
stmt = (
update(IndexAttempt)
delete(IndexAttempt)
.where(IndexAttempt.connector_credential_pair_id == cc_pair_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
.values(
status=IndexingStatus.CANCELED,
error_msg="Canceled by user",
time_started=datetime.now(timezone.utc),
)
)
if not include_secondary_index:

View File

@@ -0,0 +1,202 @@
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import InputPrompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.manage.models import UserInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
def insert_input_prompt_if_not_exists(
user: User | None,
input_prompt_id: int | None,
prompt: str,
content: str,
active: bool,
is_public: bool,
db_session: Session,
commit: bool = True,
) -> InputPrompt:
if input_prompt_id is not None:
input_prompt = (
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
)
else:
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
if user:
query = query.filter(InputPrompt.user_id == user.id)
else:
query = query.filter(InputPrompt.user_id.is_(None))
input_prompt = query.first()
if input_prompt is None:
input_prompt = InputPrompt(
id=input_prompt_id,
prompt=prompt,
content=content,
active=active,
is_public=is_public or user is None,
user_id=user.id if user else None,
)
db_session.add(input_prompt)
if commit:
db_session.commit()
return input_prompt
def insert_input_prompt(
prompt: str,
content: str,
is_public: bool,
user: User | None,
db_session: Session,
) -> InputPrompt:
input_prompt = InputPrompt(
prompt=prompt,
content=content,
active=True,
is_public=is_public or user is None,
user_id=user.id if user is not None else None,
)
db_session.add(input_prompt)
db_session.commit()
return input_prompt
def update_input_prompt(
user: User | None,
input_prompt_id: int,
prompt: str,
content: str,
active: bool,
db_session: Session,
) -> InputPrompt:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You don't own this prompt")
input_prompt.prompt = prompt
input_prompt.content = content
input_prompt.active = active
db_session.commit()
return input_prompt
def validate_user_prompt_authorization(
user: User | None, input_prompt: InputPrompt
) -> bool:
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)
if prompt.user_id is not None:
if user is None:
return False
user_details = UserInfo.from_model(user)
if str(user_details.id) != str(prompt.user_id):
return False
return True
def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not input_prompt.is_public:
raise HTTPException(status_code=400, detail="This prompt is not public")
db_session.delete(input_prompt)
db_session.commit()
def remove_input_prompt(
user: User | None, input_prompt_id: int, db_session: Session
) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if input_prompt.is_public:
raise HTTPException(
status_code=400, detail="Cannot delete public prompts with this method"
)
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You do not own this prompt")
db_session.delete(input_prompt)
db_session.commit()
def fetch_input_prompt_by_id(
id: int, user_id: UUID | None, db_session: Session
) -> InputPrompt:
query = select(InputPrompt).where(InputPrompt.id == id)
if user_id:
query = query.where(
(InputPrompt.user_id == user_id) | (InputPrompt.user_id is None)
)
else:
# If no user_id is provided, only fetch prompts without a user_id (aka public)
query = query.where(InputPrompt.user_id == None) # noqa
result = db_session.scalar(query)
if result is None:
raise HTTPException(422, "No input prompt found")
return result
def fetch_public_input_prompts(
db_session: Session,
) -> list[InputPrompt]:
query = select(InputPrompt).where(InputPrompt.is_public)
return list(db_session.scalars(query).all())
def fetch_input_prompts_by_user(
db_session: Session,
user_id: UUID | None,
active: bool | None = None,
include_public: bool = False,
) -> list[InputPrompt]:
query = select(InputPrompt)
if user_id is not None:
if include_public:
query = query.where(
(InputPrompt.user_id == user_id) | InputPrompt.is_public
)
else:
query = query.where(InputPrompt.user_id == user_id)
elif include_public:
query = query.where(InputPrompt.is_public)
if active is not None:
query = query.where(InputPrompt.active == active)
return list(db_session.scalars(query).all())

View File

@@ -159,6 +159,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
input_prompts: Mapped[list["InputPrompt"]] = relationship(
"InputPrompt", back_populates="user"
)
# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
@@ -175,6 +178,31 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)
class InputPrompt(Base):
__tablename__ = "inputprompt"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
prompt: Mapped[str] = mapped_column(String)
content: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
class InputPrompt__User(Base):
__tablename__ = "inputprompt__user"
input_prompt_id: Mapped[int] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
pass
@@ -568,25 +596,6 @@ class Connector(Base):
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
# synchronize this validation logic with RefreshFrequencySchema etc on front end
# until we have a centralized validation schema
# TODO(rkuo): experiment with SQLAlchemy validators rather than manual checks
# https://docs.sqlalchemy.org/en/20/orm/mapped_attributes.html
def validate_refresh_freq(self) -> None:
if self.refresh_freq is not None:
if self.refresh_freq < 60:
raise ValueError(
"refresh_freq must be greater than or equal to 60 seconds."
)
def validate_prune_freq(self) -> None:
if self.prune_freq is not None:
if self.prune_freq < 86400:
raise ValueError(
"prune_freq must be greater than or equal to 86400 seconds."
)
class Credential(Base):
__tablename__ = "credential"

View File

@@ -70,7 +70,7 @@ def get_file_ext(file_path_or_name: str | Path) -> str:
return extension
def is_valid_file_ext(ext: str) -> bool:
def check_file_ext_is_valid(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS
@@ -364,7 +364,7 @@ def extract_file_text(
elif file_name is not None:
final_extension = get_file_ext(file_name)
if is_valid_file_ext(final_extension):
if check_file_ext_is_valid(final_extension):
return extension_to_function.get(final_extension, file_io_to_text)(file)
# Either the file somehow has no name or the extension is not one that we recognize

View File

@@ -1,5 +1,4 @@
import traceback
from collections.abc import Callable
from functools import partial
from http import HTTPStatus
from typing import Protocol
@@ -13,7 +12,6 @@ from danswer.access.access import get_access_for_documents
from danswer.access.models import DocumentAccess
from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING
from danswer.configs.app_configs import INDEXING_EXCEPTION_LIMIT
from danswer.configs.app_configs import MAX_DOCUMENT_CHARS
from danswer.configs.constants import DEFAULT_BOOST
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
@@ -204,13 +202,40 @@ def index_doc_batch_with_handler(
def index_doc_batch_prepare(
documents: list[Document],
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This preceeds indexing it into the actual document index."""
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
and not document.semantic_identifier.strip()
and empty_contents
):
# Skip documents that have neither title nor content
# If the document doesn't have either, then there is no useful information in it
# This is again verified later in the pipeline after chunking but at that point there should
# already be no documents that are empty.
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
continue
if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unuseable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
continue
documents.append(document)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]
@@ -257,64 +282,17 @@ def index_doc_batch_prepare(
)
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
and not document.semantic_identifier.strip()
and empty_contents
):
# Skip documents that have neither title nor content
# If the document doesn't have either, then there is no useful information in it
# This is again verified later in the pipeline after chunking but at that point there should
# already be no documents that are empty.
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
continue
if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unuseable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
continue
section_chars = sum(len(section.text) for section in document.sections)
if (
MAX_DOCUMENT_CHARS
and len(document.title or document.semantic_identifier) + section_chars
> MAX_DOCUMENT_CHARS
):
# Skip documents that are too long, later on there are more memory intensive steps done on the text
# and the container will run out of memory and crash. Several other checks are included upstream but
# those are at the connector level so a catchall is still needed.
# Assumption here is that files that are that long, are generated files and not the type users
# generally care for.
logger.warning(
f"Skipping document with ID {document.id} as it is too long."
)
continue
documents.append(document)
return documents
@log_function_time(debug_only=True)
def index_doc_batch(
*,
document_batch: list[Document],
chunker: Chunker,
embedder: IndexingEmbedder,
document_index: DocumentIndex,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
@@ -331,11 +309,8 @@ def index_doc_batch(
is_public=False,
)
logger.debug("Filtering Documents")
filtered_documents = filter_fnc(document_batch)
ctx = index_doc_batch_prepare(
documents=filtered_documents,
document_batch=document_batch,
index_attempt_metadata=index_attempt_metadata,
ignore_time_skip=ignore_time_skip,
db_session=db_session,

View File

@@ -1,4 +1,5 @@
import copy
import io
import json
from collections.abc import Callable
from collections.abc import Iterator
@@ -6,6 +7,7 @@ from typing import Any
from typing import cast
import litellm # type: ignore
import pandas as pd
import tiktoken
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
@@ -98,32 +100,53 @@ def litellm_exception_to_error_msg(
return error_msg
# Processes CSV files to show the first 5 rows and max_columns (default 40) columns
def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str:
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
csv_preview = df.head().to_string(max_cols=max_columns)
file_name_section = (
f"CSV FILE NAME: {file.filename}\n"
if file.filename
else "CSV FILE (NO NAME PROVIDED):\n"
)
return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n"
def _build_content(
message: str,
files: list[InMemoryChatFile] | None = None,
) -> str:
"""Applies all non-image files."""
if not files:
return message
text_files = (
[file for file in files if file.file_type == ChatFileType.PLAIN_TEXT]
if files
else None
)
text_files = [
file
for file in files
if file.file_type in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV)
]
csv_files = (
[file for file in files if file.file_type == ChatFileType.CSV]
if files
else None
)
if not text_files:
if not text_files and not csv_files:
return message
final_message_with_files = "FILES:\n\n"
for file in text_files:
for file in text_files or []:
file_content = file.content.decode("utf-8")
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
final_message_with_files += (
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
)
for file in csv_files or []:
final_message_with_files += _process_csv_file(file)
return final_message_with_files + message
final_message_with_files += message
return final_message_with_files
def build_content_with_imgs(

View File

@@ -52,9 +52,12 @@ from danswer.server.documents.connector import router as connector_router
from danswer.server.documents.credential import router as credential_router
from danswer.server.documents.document import router as document_router
from danswer.server.documents.indexing import router as indexing_router
from danswer.server.documents.standard_oauth import router as oauth_router
from danswer.server.features.document_set.api import router as document_set_router
from danswer.server.features.folder.api import router as folder_router
from danswer.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
from danswer.server.features.notifications.api import router as notification_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
@@ -255,6 +258,8 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, notification_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
@@ -277,7 +282,6 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
include_router_with_global_prefix_prepended(application, oauth_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@@ -0,0 +1,24 @@
input_prompts:
- id: -5
prompt: "Elaborate"
content: "Elaborate on the above, give me a more in depth explanation."
active: true
is_public: true
- id: -4
prompt: "Reword"
content: "Help me rewrite the following politely and concisely for professional communication:\n"
active: true
is_public: true
- id: -3
prompt: "Email"
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
active: true
is_public: true
- id: -2
prompt: "Debug"
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
active: true
is_public: true

View File

@@ -196,7 +196,7 @@ def seed_initial_documents(
docs, chunks = _create_indexable_chunks(processed_docs, tenant_id)
index_doc_batch_prepare(
documents=docs,
document_batch=docs,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=connector_id,
credential_id=PUBLIC_CREDENTIAL_ID,

View File

@@ -1,11 +1,13 @@
import yaml
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
from danswer.db.models import Prompt as PromptDBModel
@@ -138,10 +140,35 @@ def load_personas_from_yaml(
)
def load_input_prompts_from_yaml(
db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_input_prompts = data.get("input_prompts", [])
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
def load_chat_yamls(
db_session: Session,
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(db_session, prompt_yaml)
load_personas_from_yaml(db_session, personas_yaml)
load_input_prompts_from_yaml(db_session, input_prompts_yaml)

View File

@@ -33,6 +33,8 @@ from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
@@ -43,7 +45,6 @@ from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCPropertyUpdateRequest
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairMetadata
@@ -191,6 +192,9 @@ def update_cc_pair_status(
db_session
)
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
cancel_indexing_attempts_past_model(db_session)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
@@ -304,46 +308,6 @@ def update_cc_pair_name(
raise HTTPException(status_code=400, detail="Name must be unique")
@router.put("/admin/cc-pair/{cc_pair_id}/property")
def update_cc_pair_property(
cc_pair_id: int,
update_request: CCPropertyUpdateRequest, # in seconds
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=True,
)
if not cc_pair:
raise HTTPException(
status_code=400, detail="CC Pair not found for current user's permissions"
)
# Can we centralize logic for updating connector properties
# so that we don't need to manually validate everywhere?
if update_request.name == "refresh_frequency":
cc_pair.connector.refresh_freq = int(update_request.value)
cc_pair.connector.validate_refresh_freq()
db_session.commit()
msg = "Refresh frequency updated successfully"
elif update_request.name == "pruning_frequency":
cc_pair.connector.prune_freq = int(update_request.value)
cc_pair.connector.validate_prune_freq()
db_session.commit()
msg = "Pruning frequency updated successfully"
else:
raise HTTPException(
status_code=400, detail=f"Property name {update_request.name} is not valid."
)
return StatusResponse(success=True, message=msg, data=cc_pair_id)
@router.get("/admin/cc-pair/{cc_pair_id}/last_pruned")
def get_cc_pair_last_pruned(
cc_pair_id: int,

View File

@@ -181,13 +181,7 @@ def update_credential_data(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> CredentialBase:
credential = alter_credential(
credential_id,
credential_update.name,
credential_update.credential_json,
user,
db_session,
)
credential = alter_credential(credential_id, credential_update, user, db_session)
if credential is None:
raise HTTPException(

View File

@@ -364,11 +364,6 @@ class RunConnectorRequest(BaseModel):
from_beginning: bool = False
class CCPropertyUpdateRequest(BaseModel):
name: str
value: str
"""Connectors Models"""

View File

@@ -1,142 +0,0 @@
import uuid
from typing import Annotated
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import OAuthConnector
from danswer.db.credentials import create_credential
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CredentialBase
from danswer.utils.logger import setup_logger
from danswer.utils.subclasses import find_all_subclasses_in_dir
logger = setup_logger()
router = APIRouter(prefix="/connector/oauth")
_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes
# Cache for OAuth connectors, populated at module load time
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}
def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
"""Walk through the connectors package to find all OAuthConnector implementations"""
global _OAUTH_CONNECTORS
if _OAUTH_CONNECTORS: # Return cached connectors if already discovered
return _OAUTH_CONNECTORS
oauth_connectors = find_all_subclasses_in_dir(
cast(type[OAuthConnector], OAuthConnector), "danswer.connectors"
)
_OAUTH_CONNECTORS = {cls.oauth_id(): cls for cls in oauth_connectors}
return _OAUTH_CONNECTORS
# Discover OAuth connectors at module load time
_discover_oauth_connectors()
class AuthorizeResponse(BaseModel):
redirect_url: str
@router.get("/authorize/{source}")
def oauth_authorize(
source: DocumentSource,
desired_return_url: Annotated[str | None, Query()] = None,
_: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> AuthorizeResponse:
"""Initiates the OAuth flow by redirecting to the provider's auth page"""
oauth_connectors = _discover_oauth_connectors()
if source not in oauth_connectors:
raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")
connector_cls = oauth_connectors[source]
base_url = WEB_DOMAIN
# store state in redis
if not desired_return_url:
desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
redis_client = get_redis_client(tenant_id=tenant_id)
state = str(uuid.uuid4())
redis_client.set(
_OAUTH_STATE_KEY_FMT.format(state=state),
desired_return_url,
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
)
return AuthorizeResponse(
redirect_url=connector_cls.oauth_authorization_url(base_url, state)
)
class CallbackResponse(BaseModel):
redirect_url: str
@router.get("/callback/{source}")
def oauth_callback(
source: DocumentSource,
code: Annotated[str, Query()],
state: Annotated[str, Query()],
db_session: Session = Depends(get_session),
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> CallbackResponse:
"""Handles the OAuth callback and exchanges the code for tokens"""
oauth_connectors = _discover_oauth_connectors()
if source not in oauth_connectors:
raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")
connector_cls = oauth_connectors[source]
# get state from redis
redis_client = get_redis_client(tenant_id=tenant_id)
original_url_bytes = cast(
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
)
if not original_url_bytes:
raise HTTPException(status_code=400, detail="Invalid OAuth state")
original_url = original_url_bytes.decode("utf-8")
token_info = connector_cls.oauth_code_to_token(code)
# Create a new credential with the token info
credential_data = CredentialBase(
credential_json=token_info,
admin_public=True, # Or based on some logic/parameter
source=source,
name=f"{source.title()} OAuth Credential",
)
credential = create_credential(
credential_data=credential_data,
user=user,
db_session=db_session,
)
return CallbackResponse(
redirect_url=(
f"{original_url}?credentialId={credential.id}"
if "?" not in original_url
else f"{original_url}&credentialId={credential.id}"
)
)

View File

@@ -0,0 +1,134 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.input_prompt import fetch_input_prompt_by_id
from danswer.db.input_prompt import fetch_input_prompts_by_user
from danswer.db.input_prompt import fetch_public_input_prompts
from danswer.db.input_prompt import insert_input_prompt
from danswer.db.input_prompt import remove_input_prompt
from danswer.db.input_prompt import remove_public_input_prompt
from danswer.db.input_prompt import update_input_prompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import CreateInputPromptRequest
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.features.input_prompt.models import UpdateInputPromptRequest
from danswer.utils.logger import setup_logger
logger = setup_logger()
basic_router = APIRouter(prefix="/input_prompt")
admin_router = APIRouter(prefix="/admin/input_prompt")
@basic_router.get("")
def list_input_prompts(
user: User | None = Depends(current_user),
include_public: bool = False,
db_session: Session = Depends(get_session),
) -> list[InputPromptSnapshot]:
user_prompts = fetch_input_prompts_by_user(
user_id=user.id if user is not None else None,
db_session=db_session,
include_public=include_public,
)
return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]
@basic_router.get("/{input_prompt_id}")
def get_input_prompt(
input_prompt_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
input_prompt = fetch_input_prompt_by_id(
id=input_prompt_id,
user_id=user.id if user is not None else None,
db_session=db_session,
)
return InputPromptSnapshot.from_model(input_prompt=input_prompt)
@basic_router.post("")
def create_input_prompt(
create_input_prompt_request: CreateInputPromptRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
input_prompt = insert_input_prompt(
prompt=create_input_prompt_request.prompt,
content=create_input_prompt_request.content,
is_public=create_input_prompt_request.is_public,
user=user,
db_session=db_session,
)
return InputPromptSnapshot.from_model(input_prompt)
@basic_router.patch("/{input_prompt_id}")
def patch_input_prompt(
input_prompt_id: int,
update_input_prompt_request: UpdateInputPromptRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
try:
updated_input_prompt = update_input_prompt(
user=user,
input_prompt_id=input_prompt_id,
prompt=update_input_prompt_request.prompt,
content=update_input_prompt_request.content,
active=update_input_prompt_request.active,
db_session=db_session,
)
except ValueError as e:
error_msg = "Error occurred while updated input prompt"
logger.warn(f"{error_msg}. Stack trace: {e}")
raise HTTPException(status_code=404, detail=error_msg)
return InputPromptSnapshot.from_model(updated_input_prompt)
@basic_router.delete("/{input_prompt_id}")
def delete_input_prompt(
input_prompt_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
remove_input_prompt(user, input_prompt_id, db_session)
except ValueError as e:
error_msg = "Error occurred while deleting input prompt"
logger.warn(f"{error_msg}. Stack trace: {e}")
raise HTTPException(status_code=404, detail=error_msg)
@admin_router.delete("/{input_prompt_id}")
def delete_public_input_prompt(
input_prompt_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
remove_public_input_prompt(input_prompt_id, db_session)
except ValueError as e:
error_msg = "Error occurred while deleting input prompt"
logger.warn(f"{error_msg}. Stack trace: {e}")
raise HTTPException(status_code=404, detail=error_msg)
@admin_router.get("")
def list_public_input_prompts(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[InputPromptSnapshot]:
user_prompts = fetch_public_input_prompts(
db_session=db_session,
)
return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]

View File

@@ -0,0 +1,47 @@
from uuid import UUID
from pydantic import BaseModel
from danswer.db.models import InputPrompt
from danswer.utils.logger import setup_logger
logger = setup_logger()
class CreateInputPromptRequest(BaseModel):
prompt: str
content: str
is_public: bool
class UpdateInputPromptRequest(BaseModel):
prompt: str
content: str
active: bool
class InputPromptResponse(BaseModel):
id: int
prompt: str
content: str
active: bool
class InputPromptSnapshot(BaseModel):
id: int
prompt: str
content: str
active: bool
user_id: UUID | None
is_public: bool
@classmethod
def from_model(cls, input_prompt: InputPrompt) -> "InputPromptSnapshot":
return InputPromptSnapshot(
id=input_prompt.id,
prompt=input_prompt.prompt,
content=input_prompt.content,
active=input_prompt.active,
user_id=input_prompt.user_id,
is_public=input_prompt.is_public,
)

View File

@@ -266,7 +266,5 @@ class FullModelVersionResponse(BaseModel):
class AllUsersResponse(BaseModel):
accepted: list[FullUserSnapshot]
invited: list[InvitedUserSnapshot]
slack_users: list[FullUserSnapshot]
accepted_pages: int
invited_pages: int
slack_users_pages: int

View File

@@ -119,7 +119,6 @@ def set_user_role(
def list_all_users(
q: str | None = None,
accepted_page: int | None = None,
slack_users_page: int | None = None,
invited_page: int | None = None,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
@@ -132,12 +131,7 @@ def list_all_users(
for user in list_users(db_session, email_filter_string=q)
if not is_api_key_email_address(user.email)
]
slack_users = [user for user in users if user.role == UserRole.SLACK_USER]
accepted_users = [user for user in users if user.role != UserRole.SLACK_USER]
accepted_emails = {user.email for user in accepted_users}
slack_users_emails = {user.email for user in slack_users}
accepted_emails = {user.email for user in users}
invited_emails = get_invited_users()
if q:
invited_emails = [
@@ -145,11 +139,10 @@ def list_all_users(
]
accepted_count = len(accepted_emails)
slack_users_count = len(slack_users_emails)
invited_count = len(invited_emails)
# If any of q, accepted_page, or invited_page is None, return all users
if accepted_page is None or invited_page is None or slack_users_page is None:
if accepted_page is None or invited_page is None:
return AllUsersResponse(
accepted=[
FullUserSnapshot(
@@ -160,23 +153,11 @@ def list_all_users(
UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
),
)
for user in accepted_users
],
slack_users=[
FullUserSnapshot(
id=user.id,
email=user.email,
role=user.role,
status=(
UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
),
)
for user in slack_users
for user in users
],
invited=[InvitedUserSnapshot(email=email) for email in invited_emails],
accepted_pages=1,
invited_pages=1,
slack_users_pages=1,
)
# Otherwise, return paginated results
@@ -188,27 +169,13 @@ def list_all_users(
role=user.role,
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
)
for user in accepted_users
for user in users
][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE],
slack_users=[
FullUserSnapshot(
id=user.id,
email=user.email,
role=user.role,
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
)
for user in slack_users
][
slack_users_page
* USERS_PAGE_SIZE : (slack_users_page + 1)
* USERS_PAGE_SIZE
],
invited=[InvitedUserSnapshot(email=email) for email in invited_emails][
invited_page * USERS_PAGE_SIZE : (invited_page + 1) * USERS_PAGE_SIZE
],
accepted_pages=accepted_count // USERS_PAGE_SIZE + 1,
invited_pages=invited_count // USERS_PAGE_SIZE + 1,
slack_users_pages=slack_users_count // USERS_PAGE_SIZE + 1,
)

View File

@@ -39,6 +39,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.seeding.load_docs import seed_initial_documents
from danswer.seeding.load_yamls import load_chat_yamls
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.settings.store import load_settings
@@ -150,7 +151,7 @@ def setup_danswer(
# update multipass indexing setting based on GPU availability
update_default_multipass_indexing(db_session)
# seed_initial_documents(db_session, tenant_id, cohere_enabled)
seed_initial_documents(db_session, tenant_id, cohere_enabled)
def translate_saved_search_settings(db_session: Session) -> None:

View File

@@ -48,9 +48,6 @@ from danswer.tools.tool_implementations.search_like_tool_utils import (
from danswer.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search_like_tool_utils import (
ORIGINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.utils.logger import setup_logger
from danswer.utils.special_types import JSON_ro
@@ -394,35 +391,15 @@ class SearchTool(Tool):
"""Other utility functions"""
@classmethod
def get_search_result(
cls, llm_call: LLMCall
) -> tuple[list[LlmDoc], dict[str, int]] | None:
"""
Returns the final search results and a map of docs to their original search rank (which is what is displayed to user)
"""
def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None:
if not llm_call.tool_call_info:
return None
final_search_results = []
doc_id_to_original_search_rank_map = {}
for yield_item in llm_call.tool_call_info:
if (
isinstance(yield_item, ToolResponse)
and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID
):
final_search_results = cast(list[LlmDoc], yield_item.response)
elif (
isinstance(yield_item, ToolResponse)
and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
):
search_contexts = yield_item.response.contexts
original_doc_search_rank = 1
for idx, doc in enumerate(search_contexts):
if doc.document_id not in doc_id_to_original_search_rank_map:
doc_id_to_original_search_rank_map[
doc.document_id
] = original_doc_search_rank
original_doc_search_rank += 1
return cast(list[LlmDoc], yield_item.response)
return final_search_results, doc_id_to_original_search_rank_map
return None

View File

@@ -15,7 +15,6 @@ from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content"
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"

View File

@@ -1,77 +0,0 @@
from __future__ import annotations
import importlib
import os
import pkgutil
import sys
from types import ModuleType
from typing import List
from typing import Type
from typing import TypeVar
T = TypeVar("T")
def import_all_modules_from_dir(dir_path: str) -> List[ModuleType]:
"""
Imports all modules found in the given directory and its subdirectories,
returning a list of imported module objects.
"""
dir_path = os.path.abspath(dir_path)
if dir_path not in sys.path:
sys.path.insert(0, dir_path)
imported_modules: List[ModuleType] = []
for _, package_name, _ in pkgutil.walk_packages([dir_path]):
try:
module = importlib.import_module(package_name)
imported_modules.append(module)
except Exception as e:
# Handle or log exceptions as needed
print(f"Could not import {package_name}: {e}")
return imported_modules
def all_subclasses(cls: Type[T]) -> List[Type[T]]:
"""
Recursively find all subclasses of the given class.
"""
direct_subs = cls.__subclasses__()
result: List[Type[T]] = []
for subclass in direct_subs:
result.append(subclass)
# Extend the result by recursively calling all_subclasses
result.extend(all_subclasses(subclass))
return result
def find_all_subclasses_in_dir(parent_class: Type[T], directory: str) -> List[Type[T]]:
"""
Imports all modules from the given directory (and subdirectories),
then returns all classes that are subclasses of parent_class.
:param parent_class: The class to find subclasses of.
:param directory: The directory to search for subclasses.
:return: A list of all subclasses of parent_class found in the directory.
"""
# First import all modules to ensure classes are loaded into memory
import_all_modules_from_dir(directory)
# Gather all subclasses of the given parent class
subclasses = all_subclasses(parent_class)
return subclasses
# Example usage:
if __name__ == "__main__":
class Animal:
pass
# Suppose "mymodules" contains files that define classes inheriting from Animal
found_subclasses = find_all_subclasses_in_dir(Animal, "mymodules")
for sc in found_subclasses:
print("Found subclass:", sc.__name__)

View File

@@ -76,7 +76,7 @@ def replace_user__ext_group_for_cc_pair(
new_external_permissions = []
for external_group in group_defs:
for user_email in external_group.user_emails:
user_id = email_id_map.get(user_email.lower())
user_id = email_id_map.get(user_email)
if user_id is None:
logger.warning(
f"User in group {external_group.id}"

View File

@@ -1,6 +1,4 @@
import asyncio
import json
from types import TracebackType
from typing import cast
from typing import Optional
@@ -8,11 +6,11 @@ import httpx
import openai
import vertexai # type: ignore
import voyageai # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from cohere import Client as CohereClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from litellm import aembedding
from litellm import embedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
@@ -65,31 +63,22 @@ class CloudEmbedding:
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
timeout: int = API_BASED_EMBEDDING_TIMEOUT,
) -> None:
self.provider = provider
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.timeout = timeout
self.http_client = httpx.AsyncClient(timeout=timeout)
self._closed = False
async def _embed_openai(
self, texts: list[str], model: str | None
) -> list[Embedding]:
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
# Use the OpenAI specific timeout for this one
client = openai.AsyncOpenAI(
api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
)
client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(input=text_batch, model=model)
response = client.embeddings.create(input=text_batch, model=model)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
@@ -104,19 +93,19 @@ class CloudEmbedding:
logger.error(error_string)
raise RuntimeError(error_string)
async def _embed_cohere(
def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_COHERE_MODEL
client = CohereAsyncClient(api_key=self.api_key)
client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
# Does not use the same tokenizer as the Danswer API server but it's approximately the same
# empirically it's only off by a very few tokens so it's not a big deal
response = await client.embed(
response = client.embed(
texts=text_batch,
model=model,
input_type=embedding_type,
@@ -125,29 +114,26 @@ class CloudEmbedding:
final_embeddings.extend(cast(list[Embedding], response.embeddings))
return final_embeddings
async def _embed_voyage(
def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_VOYAGE_MODEL
client = voyageai.AsyncClient(
client = voyageai.Client(
api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
)
response = await client.embed(
texts=texts,
response = client.embed(
texts,
model=model,
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
self, texts: list[str], model: str | None
) -> list[Embedding]:
response = await aembedding(
def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
response = embedding(
model=model,
input=texts,
timeout=API_BASED_EMBEDDING_TIMEOUT,
@@ -156,9 +142,10 @@ class CloudEmbedding:
api_version=self.api_version,
)
embeddings = [embedding["embedding"] for embedding in response.data]
return embeddings
async def _embed_vertex(
def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
@@ -171,7 +158,7 @@ class CloudEmbedding:
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)
embeddings = await client.get_embeddings_async(
embeddings = client.get_embeddings(
[
TextEmbeddingInput(
text,
@@ -179,11 +166,11 @@ class CloudEmbedding:
)
for text in texts
],
auto_truncate=True, # This is the default
auto_truncate=True, # Also this is default
)
return [embedding.values for embedding in embeddings]
async def _embed_litellm_proxy(
def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
) -> list[Embedding]:
if not model_name:
@@ -196,20 +183,22 @@ class CloudEmbedding:
{} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
)
response = await self.http_client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]
with httpx.Client() as client:
response = client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
timeout=API_BASED_EMBEDDING_TIMEOUT,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
async def embed(
def embed(
self,
*,
texts: list[str],
@@ -218,19 +207,19 @@ class CloudEmbedding:
deployment_name: str | None = None,
) -> list[Embedding]:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name)
return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
return self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
return self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
@@ -244,30 +233,6 @@ class CloudEmbedding:
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, api_url, api_version)
async def aclose(self) -> None:
"""Explicitly close the client."""
if not self._closed:
await self.http_client.aclose()
self._closed = True
async def __aenter__(self) -> "CloudEmbedding":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
def __del__(self) -> None:
"""Finalizer to warn about unclosed clients."""
if not self._closed:
logger.warning(
"CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
)
def get_embedding_model(
model_name: str,
@@ -277,6 +242,9 @@ def get_embedding_model(
global _GLOBAL_MODELS_DICT # A dictionary to store models
if _GLOBAL_MODELS_DICT is None:
_GLOBAL_MODELS_DICT = {}
if model_name not in _GLOBAL_MODELS_DICT:
logger.notice(f"Loading {model_name}")
# Some model architectures that aren't built into the Transformers or Sentence
@@ -307,7 +275,7 @@ def get_local_reranking_model(
@simple_log_function_time()
async def embed_text(
def embed_text(
texts: list[str],
text_type: EmbedTextType,
model_name: str | None,
@@ -343,18 +311,18 @@ async def embed_text(
"Cloud models take an explicit text type instead."
)
async with CloudEmbedding(
cloud_model = CloudEmbedding(
api_key=api_key,
provider=provider_type,
api_url=api_url,
api_version=api_version,
) as cloud_model:
embeddings = await cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
)
)
embeddings = cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
)
if any(embedding is None for embedding in embeddings):
error_message = "Embeddings contain None values\n"
@@ -370,12 +338,8 @@ async def embed_text(
local_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
# Run CPU-bound embedding in a thread pool
embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
None,
lambda: local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
),
embeddings_vectors = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
)
embeddings = [
embedding if isinstance(embedding, list) else embedding.tolist()
@@ -393,31 +357,27 @@ async def embed_text(
@simple_log_function_time()
async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
cross_encoder = get_local_reranking_model(model_name)
# Run CPU-bound reranking in a thread pool
return await asyncio.get_event_loop().run_in_executor(
None,
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
)
return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
async def cohere_rerank(
def cohere_rerank(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereAsyncClient(api_key=api_key)
response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
cohere_client = CohereClient(api_key=api_key)
response = cohere_client.rerank(query=query, documents=docs, model=model_name)
results = response.results
sorted_results = sorted(results, key=lambda item: item.index)
return [result.relevance_score for result in sorted_results]
async def litellm_rerank(
def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
async with httpx.AsyncClient() as client:
response = await client.post(
with httpx.Client() as client:
response = client.post(
api_url,
json={
"model": model_name,
@@ -451,7 +411,7 @@ async def process_embed_request(
else:
prefix = None
embeddings = await embed_text(
embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
deployment_name=embed_request.deployment_name,
@@ -491,7 +451,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
try:
if rerank_request.provider_type is None:
sim_scores = await local_rerank(
sim_scores = local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
@@ -501,7 +461,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
if rerank_request.api_url is None:
raise ValueError("API URL is required for LiteLLM reranking.")
sim_scores = await litellm_rerank(
sim_scores = litellm_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
api_url=rerank_request.api_url,
@@ -514,7 +474,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = await cohere_rerank(
sim_scores = cohere_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,

View File

@@ -6,12 +6,12 @@ router = APIRouter(prefix="/api")
@router.get("/health")
async def healthcheck() -> Response:
def healthcheck() -> Response:
return Response(status_code=200)
@router.get("/gpu-status")
async def gpu_status() -> dict[str, bool | str]:
def gpu_status() -> dict[str, bool | str]:
if torch.cuda.is_available():
return {"gpu_available": True, "type": "cuda"}
elif torch.backends.mps.is_available():

View File

@@ -1,4 +1,3 @@
import asyncio
import time
from collections.abc import Callable
from collections.abc import Generator
@@ -22,39 +21,21 @@ def simple_log_function_time(
include_args: bool = False,
) -> Callable[[F], F]:
def decorator(func: F) -> F:
if asyncio.iscoroutinefunction(func):
@wraps(func)
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.notice(final_log)
@wraps(func)
async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = await func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.notice(final_log)
return result
return result
return cast(F, wrapped_async_func)
else:
@wraps(func)
def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.notice(final_log)
return result
return cast(F, wrapped_sync_func)
return cast(F, wrapped_func)
return decorator

View File

@@ -29,7 +29,7 @@ trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.54.1
litellm==1.53.1
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45

View File

@@ -1,34 +1,30 @@
black==23.3.0
boto3-stubs[s3]==1.34.133
celery-types==0.19.0
cohere==5.6.1
google-cloud-aiplatform==1.58.0
lxml==5.3.0
lxml_html_clean==0.2.2
mypy-extensions==1.0.0
mypy==1.8.0
pandas-stubs==2.2.3.241009
pandas==2.2.3
pre-commit==3.2.2
pytest-asyncio==0.22.0
pytest==7.4.4
reorder-python-imports==3.9.0
ruff==0.0.286
sentence-transformers==2.6.1
trafilatura==1.12.2
types-PyYAML==6.0.12.11
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-oauthlib==3.2.0.9
types-passlib==1.7.7.20240106
types-setuptools==68.0.0.3
types-Pillow==10.2.0.20240822
types-passlib==1.7.7.20240106
types-psutil==5.9.5.17
types-psycopg2==2.9.21.10
types-python-dateutil==2.8.19.13
types-pytz==2023.3.1.1
types-PyYAML==6.0.12.11
types-regex==2023.3.23.1
types-requests==2.28.11.17
types-retry==0.9.9.3
types-setuptools==68.0.0.3
types-urllib3==1.26.25.11
voyageai==0.2.3
trafilatura==1.12.2
lxml==5.3.0
lxml_html_clean==0.2.2
boto3-stubs[s3]==1.34.133
pandas==2.2.3
pandas-stubs==2.2.3.241009
cohere==5.6.1

View File

@@ -12,5 +12,5 @@ torch==2.2.0
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.54.1
litellm==1.50.2
sentry-sdk[fastapi,celery,starlette]==2.14.0

View File

@@ -69,10 +69,8 @@ class TenantManager:
return AllUsersResponse(
accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]],
accepted_pages=data["accepted_pages"],
invited_pages=data["invited_pages"],
slack_users_pages=data["slack_users_pages"],
)
@staticmethod

View File

@@ -130,10 +130,8 @@ class UserManager:
all_users = AllUsersResponse(
accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]],
accepted_pages=data["accepted_pages"],
invited_pages=data["invited_pages"],
slack_users_pages=data["slack_users_pages"],
)
for accepted_user in all_users.accepted:
if accepted_user.email == user.email and accepted_user.id == user.id:

View File

@@ -3,8 +3,6 @@ from datetime import datetime
from datetime import timezone
from typing import Any
import pytest
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.server.documents.models import DocumentSource
@@ -25,7 +23,7 @@ from tests.integration.common_utils.vespa import vespa_fixture
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
@pytest.mark.xfail(reason="flaky - see DAN-789 for example", strict=False)
# @pytest.mark.xfail(reason="flaky - see DAN-789 for example", strict=False)
def test_slack_permission_sync(
reset: None,
vespa_client: vespa_fixture,

View File

@@ -27,6 +27,13 @@ def test_limited(reset: None) -> None:
)
assert response.status_code == 200
# test basic endpoints
response = requests.get(
f"{API_SERVER_URL}/input_prompt",
headers=api_key.headers,
)
assert response.status_code == 403
# test admin endpoints
response = requests.get(
f"{API_SERVER_URL}/admin/api-key",

View File

@@ -72,10 +72,8 @@ def process_text(
processor = CitationProcessor(
context_docs=mock_docs,
doc_id_to_rank_map=mapping,
display_doc_order_dict=mock_doc_id_to_rank_map,
stop_stream=None,
)
result: list[DanswerAnswerPiece | CitationInfo] = []
for token in tokens:
result.extend(processor.process_token(token))
@@ -88,7 +86,6 @@ def process_text(
final_answer_text += piece.answer_piece or ""
elif isinstance(piece, CitationInfo):
citations.append(piece)
return final_answer_text, citations

View File

@@ -1,132 +0,0 @@
from datetime import datetime
import pytest
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.stream_processing.citation_processing import CitationProcessor
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.configs.constants import DocumentSource
"""
This module contains tests for the citation extraction functionality in Danswer,
specifically the substitution of the number of document cited in the UI. (The LLM
will see the sources post re-ranking and relevance check, the UI before these steps.)
This module is a derivative of test_citation_processing.py.
The tests focusses specifically on the substitution of the number of document cited in the UI.
Key components:
- mock_docs: A list of mock LlmDoc objects used for testing.
- mock_doc_mapping: A dictionary mapping document IDs to their initial ranks.
- mock_doc_mapping_rerank: A dictionary mapping document IDs to their ranks after re-ranking/relevance check.
- process_text: A helper function that simulates the citation extraction process.
- test_citation_extraction: A parametrized test function covering various citation scenarios.
To add new test cases:
1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_extraction.
2. Each tuple should contain:
- A descriptive test name (string)
- Input tokens (list of strings)
- Expected output text (string)
- Expected citations (list of document IDs)
"""
mock_docs = [
LlmDoc(
document_id=f"doc_{int(id/2)}",
content="Document is a doc",
blurb=f"Document #{id}",
semantic_identifier=f"Doc {id}",
source_type=DocumentSource.WEB,
metadata={},
updated_at=datetime.now(),
link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
source_links={0: "https://mintlify.com/docs/settings/broken-links"},
match_highlights=[],
)
for id in range(10)
]
mock_doc_mapping = {
"doc_0": 1,
"doc_1": 2,
"doc_2": 3,
"doc_3": 4,
"doc_4": 5,
"doc_5": 6,
}
mock_doc_mapping_rerank = {
"doc_0": 2,
"doc_1": 1,
"doc_2": 4,
"doc_3": 3,
"doc_4": 6,
"doc_5": 5,
}
@pytest.fixture
def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
return mock_docs, mock_doc_mapping
def process_text(
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
) -> tuple[str, list[CitationInfo]]:
mock_docs, mock_doc_id_to_rank_map = mock_data
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
processor = CitationProcessor(
context_docs=mock_docs,
doc_id_to_rank_map=mapping,
display_doc_order_dict=mock_doc_mapping_rerank,
stop_stream=None,
)
result: list[DanswerAnswerPiece | CitationInfo] = []
for token in tokens:
result.extend(processor.process_token(token))
result.extend(processor.process_token(None))
final_answer_text = ""
citations = []
for piece in result:
if isinstance(piece, DanswerAnswerPiece):
final_answer_text += piece.answer_piece or ""
elif isinstance(piece, CitationInfo):
citations.append(piece)
return final_answer_text, citations
@pytest.mark.parametrize(
"test_name, input_tokens, expected_text, expected_citations",
[
(
"Single citation",
["Gro", "wth! [", "1", "]", "."],
"Growth! [[2]](https://0.com).",
["doc_0"],
),
],
)
def test_citation_substitution(
mock_data: tuple[list[LlmDoc], dict[str, int]],
test_name: str,
input_tokens: list[str],
expected_text: str,
expected_citations: list[str],
) -> None:
final_answer_text, citations = process_text(input_tokens, mock_data)
assert (
final_answer_text.strip() == expected_text.strip()
), f"Test '{test_name}' failed: Final answer text does not match expected output."
assert [
citation.document_id for citation in citations
] == expected_citations, (
f"Test '{test_name}' failed: Citations do not match expected output."
)

View File

@@ -1,120 +0,0 @@
from typing import List
from danswer.configs.app_configs import MAX_DOCUMENT_CHARS
from danswer.connectors.models import Document
from danswer.connectors.models import DocumentSource
from danswer.connectors.models import Section
from danswer.indexing.indexing_pipeline import filter_documents
def create_test_document(
doc_id: str = "test_id",
title: str | None = "Test Title",
semantic_id: str = "test_semantic_id",
sections: List[Section] | None = None,
) -> Document:
if sections is None:
sections = [Section(text="Test content", link="test_link")]
return Document(
id=doc_id,
title=title,
semantic_identifier=semantic_id,
sections=sections,
source=DocumentSource.FILE,
metadata={},
)
def test_filter_documents_empty_title_and_content() -> None:
doc = create_test_document(
title="", semantic_id="", sections=[Section(text="", link="test_link")]
)
result = filter_documents([doc])
assert len(result) == 0
def test_filter_documents_empty_title_with_content() -> None:
doc = create_test_document(
title="", sections=[Section(text="Valid content", link="test_link")]
)
result = filter_documents([doc])
assert len(result) == 1
assert result[0].id == "test_id"
def test_filter_documents_empty_content_with_title() -> None:
doc = create_test_document(
title="Valid Title", sections=[Section(text="", link="test_link")]
)
result = filter_documents([doc])
assert len(result) == 1
assert result[0].id == "test_id"
def test_filter_documents_exceeding_max_chars() -> None:
if not MAX_DOCUMENT_CHARS: # Skip if no max chars configured
return
long_text = "a" * (MAX_DOCUMENT_CHARS + 1)
doc = create_test_document(sections=[Section(text=long_text, link="test_link")])
result = filter_documents([doc])
assert len(result) == 0
def test_filter_documents_valid_document() -> None:
doc = create_test_document(
title="Valid Title", sections=[Section(text="Valid content", link="test_link")]
)
result = filter_documents([doc])
assert len(result) == 1
assert result[0].id == "test_id"
assert result[0].title == "Valid Title"
def test_filter_documents_whitespace_only() -> None:
doc = create_test_document(
title=" ", semantic_id=" ", sections=[Section(text=" ", link="test_link")]
)
result = filter_documents([doc])
assert len(result) == 0
def test_filter_documents_semantic_id_no_title() -> None:
doc = create_test_document(
title=None,
semantic_id="Valid Semantic ID",
sections=[Section(text="Valid content", link="test_link")],
)
result = filter_documents([doc])
assert len(result) == 1
assert result[0].semantic_identifier == "Valid Semantic ID"
def test_filter_documents_multiple_sections() -> None:
doc = create_test_document(
sections=[
Section(text="Content 1", link="test_link"),
Section(text="Content 2", link="test_link"),
Section(text="Content 3", link="test_link"),
]
)
result = filter_documents([doc])
assert len(result) == 1
assert len(result[0].sections) == 3
def test_filter_documents_multiple_documents() -> None:
docs = [
create_test_document(doc_id="1", title="Title 1"),
create_test_document(
doc_id="2", title="", sections=[Section(text="", link="test_link")]
), # Should be filtered
create_test_document(doc_id="3", title="Title 3"),
]
result = filter_documents(docs)
assert len(result) == 2
assert {doc.id for doc in result} == {"1", "3"}
def test_filter_documents_empty_batch() -> None:
result = filter_documents([])
assert len(result) == 0

View File

@@ -1,198 +0,0 @@
import asyncio
import time
from collections.abc import AsyncGenerator
from typing import Any
from typing import List
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from httpx import AsyncClient
from litellm.exceptions import RateLimitError
from model_server.encoders import CloudEmbedding
from model_server.encoders import embed_text
from model_server.encoders import local_rerank
from model_server.encoders import process_embed_request
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbedRequest
@pytest.fixture
async def mock_http_client() -> AsyncGenerator[AsyncMock, None]:
with patch("httpx.AsyncClient") as mock:
client = AsyncMock(spec=AsyncClient)
mock.return_value = client
client.post = AsyncMock()
async with client as c:
yield c
@pytest.fixture
def sample_embeddings() -> List[List[float]]:
return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@pytest.mark.asyncio
async def test_cloud_embedding_context_manager() -> None:
async with CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) as embedding:
assert not embedding._closed
assert embedding._closed
@pytest.mark.asyncio
async def test_cloud_embedding_explicit_close() -> None:
embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
assert not embedding._closed
await embedding.aclose()
assert embedding._closed
@pytest.mark.asyncio
async def test_openai_embedding(
mock_http_client: AsyncMock, sample_embeddings: List[List[float]]
) -> None:
with patch("openai.AsyncOpenAI") as mock_openai:
mock_client = AsyncMock()
mock_openai.return_value = mock_client
mock_response = MagicMock()
mock_response.data = [MagicMock(embedding=emb) for emb in sample_embeddings]
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
result = await embedding._embed_openai(
["test1", "test2"], "text-embedding-ada-002"
)
assert result == sample_embeddings
mock_client.embeddings.create.assert_called_once()
@pytest.mark.asyncio
async def test_embed_text_cloud_provider() -> None:
with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
mock_embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
mock_embed.side_effect = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]])
result = await embed_text(
texts=["test1", "test2"],
text_type=EmbedTextType.QUERY,
model_name="fake-model",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key="fake-key",
provider_type=EmbeddingProvider.OPENAI,
prefix=None,
api_url=None,
api_version=None,
)
assert result == [[0.1, 0.2], [0.3, 0.4]]
mock_embed.assert_called_once()
@pytest.mark.asyncio
async def test_embed_text_local_model() -> None:
with patch("model_server.encoders.get_embedding_model") as mock_get_model:
mock_model = MagicMock()
mock_model.encode.return_value = [[0.1, 0.2], [0.3, 0.4]]
mock_get_model.return_value = mock_model
result = await embed_text(
texts=["test1", "test2"],
text_type=EmbedTextType.QUERY,
model_name="fake-local-model",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key=None,
provider_type=None,
prefix=None,
api_url=None,
api_version=None,
)
assert result == [[0.1, 0.2], [0.3, 0.4]]
mock_model.encode.assert_called_once()
@pytest.mark.asyncio
async def test_local_rerank() -> None:
with patch("model_server.encoders.get_local_reranking_model") as mock_get_model:
mock_model = MagicMock()
mock_array = MagicMock()
mock_array.tolist.return_value = [0.8, 0.6]
mock_model.predict.return_value = mock_array
mock_get_model.return_value = mock_model
result = await local_rerank(
query="test query", docs=["doc1", "doc2"], model_name="fake-rerank-model"
)
assert result == [0.8, 0.6]
mock_model.predict.assert_called_once()
@pytest.mark.asyncio
async def test_rate_limit_handling() -> None:
with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
mock_embed.side_effect = RateLimitError(
"Rate limit exceeded", llm_provider="openai", model="fake-model"
)
with pytest.raises(RateLimitError):
await embed_text(
texts=["test"],
text_type=EmbedTextType.QUERY,
model_name="fake-model",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key="fake-key",
provider_type=EmbeddingProvider.OPENAI,
prefix=None,
api_url=None,
api_version=None,
)
@pytest.mark.asyncio
async def test_concurrent_embeddings() -> None:
def mock_encode(*args: Any, **kwargs: Any) -> List[List[float]]:
time.sleep(5)
return [[0.1, 0.2, 0.3]]
test_req = EmbedRequest(
texts=["test"],
model_name="'nomic-ai/nomic-embed-text-v1'",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key=None,
provider_type=None,
text_type=EmbedTextType.QUERY,
manual_query_prefix=None,
manual_passage_prefix=None,
api_url=None,
api_version=None,
)
with patch("model_server.encoders.get_embedding_model") as mock_get_model:
mock_model = MagicMock()
mock_model.encode = mock_encode
mock_get_model.return_value = mock_model
start_time = time.time()
tasks = [process_embed_request(test_req) for _ in range(5)]
await asyncio.gather(*tasks)
end_time = time.time()
# 5 * 5 seconds = 25 seconds, this test ensures that the embeddings are at least yielding the thread
# However, the developer may still introduce unnecessary blocking above the mock and this test will
# still pass as long as it's less than (7 - 5) / 5 seconds
assert end_time - start_time < 7

View File

@@ -6,7 +6,7 @@ chart-dirs:
# must be kept in sync with Chart.yaml
chart-repos:
- vespa=https://onyx-dot-app.github.io/vespa-helm-charts
- vespa=https://danswer-ai.github.io/vespa-helm-charts
- postgresql=https://charts.bitnami.com/bitnami
helm-extra-args: --debug --timeout 600s

View File

@@ -183,13 +183,6 @@ services:
- GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-}
- NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
- GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
- MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-}
- MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-}
# Egnyte OAuth Configs
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
- EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-}
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
# Celery Configs (defaults are set in the supervisord.conf file.
# prefer doing that to have one source of defaults)
- CELERY_WORKER_INDEXING_CONCURRENCY=${CELERY_WORKER_INDEXING_CONCURRENCY:-}

View File

@@ -3,13 +3,13 @@ dependencies:
repository: https://charts.bitnami.com/bitnami
version: 14.3.1
- name: vespa
repository: https://onyx-dot-app.github.io/vespa-helm-charts
version: 0.2.18
repository: https://danswer-ai.github.io/vespa-helm-charts
version: 0.2.16
- name: nginx
repository: oci://registry-1.docker.io/bitnamicharts
version: 15.14.0
- name: redis
repository: https://charts.bitnami.com/bitnami
version: 20.1.0
digest: sha256:5c9eb3d55d5f8e3beb64f26d26f686c8d62755daa10e2e6d87530bdf2fbbf957
generated: "2024-12-10T10:47:35.812483-08:00"
digest: sha256:711bbb76ba6ab604a36c9bf1839ab6faa5610afb21e535afd933c78f2d102232
generated: "2024-11-07T09:39:30.17171-08:00"

View File

@@ -23,8 +23,8 @@ dependencies:
repository: https://charts.bitnami.com/bitnami
condition: postgresql.enabled
- name: vespa
version: 0.2.18
repository: https://onyx-dot-app.github.io/vespa-helm-charts
version: 0.2.16
repository: https://danswer-ai.github.io/vespa-helm-charts
condition: vespa.enabled
- name: nginx
version: 15.14.0

View File

@@ -61,8 +61,6 @@ data:
WEB_CONNECTOR_VALIDATE_URLS: ""
GONG_CONNECTOR_START_TIME: ""
NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP: ""
MAX_DOCUMENT_CHARS: ""
MAX_FILE_SIZE_BYTES: ""
# DanswerBot SlackBot Configs
DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER: ""
DANSWER_BOT_DISPLAY_ERROR_MSGS: ""

View File

@@ -66,9 +66,6 @@ ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}
ARG NEXT_PUBLIC_CLOUD_ENABLED
ENV NEXT_PUBLIC_CLOUD_ENABLED=${NEXT_PUBLIC_CLOUD_ENABLED}
ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}
@@ -141,9 +138,6 @@ ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}
ARG NEXT_PUBLIC_CLOUD_ENABLED
ENV NEXT_PUBLIC_CLOUD_ENABLED=${NEXT_PUBLIC_CLOUD_ENABLED}
ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 769 KiB

535
web/public/Wikipedia.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 164 KiB

View File

@@ -2,11 +2,6 @@ import CardSection from "@/components/admin/CardSection";
import { getNameFromPath } from "@/lib/fileUtils";
import { ValidSources } from "@/lib/types";
import Title from "@/components/ui/title";
import { EditIcon } from "@/components/icons/icons";
import { useState } from "react";
import { ChevronUpIcon } from "lucide-react";
import { ChevronDownIcon } from "@/components/icons/icons";
function convertObjectToString(obj: any): string | any {
// Check if obj is an object and not an array or null
@@ -44,83 +39,14 @@ function buildConfigEntries(
return obj;
}
function ConfigItem({ label, value }: { label: string; value: any }) {
const [isExpanded, setIsExpanded] = useState(false);
const isExpandable = Array.isArray(value) && value.length > 5;
const renderValue = () => {
if (Array.isArray(value)) {
const displayedItems = isExpanded ? value : value.slice(0, 5);
return (
<ul className="list-disc max-w-full pl-4 mt-2 overflow-x-auto">
{displayedItems.map((item, index) => (
<li
key={index}
className="mb-1 max-w-full overflow-hidden text-right text-ellipsis whitespace-nowrap"
>
{convertObjectToString(item)}
</li>
))}
</ul>
);
} else if (typeof value === "object" && value !== null) {
return (
<div className="mt-2 overflow-x-auto">
{Object.entries(value).map(([key, val]) => (
<div key={key} className="mb-1">
<span className="font-semibold">{key}:</span>{" "}
{convertObjectToString(val)}
</div>
))}
</div>
);
}
return convertObjectToString(value) || "-";
};
return (
<li className="w-full py-2">
<div className="flex items-center justify-between w-full">
<span className="mb-2">{label}</span>
<div className="mt-2 overflow-x-auto w-fit">
{renderValue()}
{isExpandable && (
<button
onClick={() => setIsExpanded(!isExpanded)}
className="mt-2 ml-auto text-text-600 hover:text-text-800 flex items-center"
>
{isExpanded ? (
<>
<ChevronUpIcon className="h-4 w-4 mr-1" />
Show less
</>
) : (
<>
<ChevronDownIcon className="h-4 w-4 mr-1" />
Show all ({value.length} items)
</>
)}
</button>
)}
</div>
</div>
</li>
);
}
export function AdvancedConfigDisplay({
pruneFreq,
refreshFreq,
indexingStart,
onRefreshEdit,
onPruningEdit,
}: {
pruneFreq: number | null;
refreshFreq: number | null;
indexingStart: Date | null;
onRefreshEdit: () => void;
onPruningEdit: () => void;
}) {
const formatRefreshFrequency = (seconds: number | null): string => {
if (seconds === null) return "-";
@@ -149,21 +75,14 @@ export function AdvancedConfigDisplay({
<>
<Title className="mt-8 mb-2">Advanced Configuration</Title>
<CardSection>
<ul className="w-full text-sm divide-y divide-background-200 dark:divide-background-700">
<ul className="w-full text-sm divide-y divide-neutral-200 dark:divide-neutral-700">
{pruneFreq && (
<li
key={0}
className="w-full flex justify-between items-center py-2"
>
<span>Pruning Frequency</span>
<span className="ml-auto w-24">
{formatPruneFrequency(pruneFreq)}
</span>
<span className="w-8 text-right">
<button onClick={() => onPruningEdit()}>
<EditIcon size={12} />
</button>
</span>
<span>{formatPruneFrequency(pruneFreq)}</span>
</li>
)}
{refreshFreq && (
@@ -172,14 +91,7 @@ export function AdvancedConfigDisplay({
className="w-full flex justify-between items-center py-2"
>
<span>Refresh Frequency</span>
<span className="ml-auto w-24">
{formatRefreshFrequency(refreshFreq)}
</span>
<span className="w-8 text-right">
<button onClick={() => onRefreshEdit()}>
<EditIcon size={12} />
</button>
</span>
<span>{formatRefreshFrequency(refreshFreq)}</span>
</li>
)}
{indexingStart && (
@@ -215,9 +127,15 @@ export function ConfigDisplay({
<>
<Title className="mb-2">Configuration</Title>
<CardSection>
<ul className="w-full text-sm divide-y divide-background-200 dark:divide-background-700">
<ul className="w-full text-sm divide-y divide-neutral-200 dark:divide-neutral-700">
{configEntries.map(([key, value]) => (
<ConfigItem key={key} label={key} value={value} />
<li
key={key}
className="w-full flex justify-between items-center py-2"
>
<span>{key}</span>
<span>{convertObjectToString(value) || "-"}</span>
</li>
))}
</ul>
</CardSection>

View File

@@ -7,10 +7,7 @@ import { SourceIcon } from "@/components/SourceIcon";
import { CCPairStatus } from "@/components/Status";
import { usePopup } from "@/components/admin/connectors/Popup";
import CredentialSection from "@/components/credentials/CredentialSection";
import {
updateConnectorCredentialPairName,
updateConnectorCredentialPairProperty,
} from "@/lib/connector";
import { updateConnectorCredentialPairName } from "@/lib/connector";
import { credentialTemplates } from "@/lib/connectors/credentials";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { ValidSources } from "@/lib/types";
@@ -29,33 +26,12 @@ import { buildCCPairInfoUrl } from "./lib";
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { Button } from "@/components/ui/button";
import EditPropertyModal from "@/components/modals/EditPropertyModal";
import * as Yup from "yup";
// since the uploaded files are cleaned up after some period of time
// re-indexing will not work for the file connector. Also, it would not
// make sense to re-index, since the files will not have changed.
const CONNECTOR_TYPES_THAT_CANT_REINDEX: ValidSources[] = [ValidSources.File];
// synchronize these validations with the SQLAlchemy connector class until we have a
// centralized schema for both frontend and backend
const RefreshFrequencySchema = Yup.object().shape({
propertyValue: Yup.number()
.typeError("Property value must be a valid number")
.integer("Property value must be an integer")
.min(60, "Property value must be greater than or equal to 60")
.required("Property value is required"),
});
const PruneFrequencySchema = Yup.object().shape({
propertyValue: Yup.number()
.typeError("Property value must be a valid number")
.integer("Property value must be an integer")
.min(86400, "Property value must be greater than or equal to 86400")
.required("Property value is required"),
});
function Main({ ccPairId }: { ccPairId: number }) {
const router = useRouter(); // Initialize the router
const {
@@ -69,8 +45,6 @@ function Main({ ccPairId }: { ccPairId: number }) {
);
const [hasLoadedOnce, setHasLoadedOnce] = useState(false);
const [editingRefreshFrequency, setEditingRefreshFrequency] = useState(false);
const [editingPruningFrequency, setEditingPruningFrequency] = useState(false);
const { popup, setPopup } = usePopup();
const finishConnectorDeletion = useCallback(() => {
@@ -116,86 +90,6 @@ function Main({ ccPairId }: { ccPairId: number }) {
}
};
const handleRefreshEdit = async () => {
setEditingRefreshFrequency(true);
};
const handlePruningEdit = async () => {
setEditingPruningFrequency(true);
};
const handleRefreshSubmit = async (
propertyName: string,
propertyValue: string
) => {
const parsedRefreshFreq = parseInt(propertyValue, 10);
if (isNaN(parsedRefreshFreq)) {
setPopup({
message: "Invalid refresh frequency: must be an integer",
type: "error",
});
return;
}
try {
const response = await updateConnectorCredentialPairProperty(
ccPairId,
propertyName,
String(parsedRefreshFreq)
);
if (!response.ok) {
throw new Error(await response.text());
}
mutate(buildCCPairInfoUrl(ccPairId));
setPopup({
message: "Connector refresh frequency updated successfully",
type: "success",
});
} catch (error) {
setPopup({
message: "Failed to update connector refresh frequency",
type: "error",
});
}
};
const handlePruningSubmit = async (
propertyName: string,
propertyValue: string
) => {
const parsedFreq = parseInt(propertyValue, 10);
if (isNaN(parsedFreq)) {
setPopup({
message: "Invalid pruning frequency: must be an integer",
type: "error",
});
return;
}
try {
const response = await updateConnectorCredentialPairProperty(
ccPairId,
propertyName,
String(parsedFreq)
);
if (!response.ok) {
throw new Error(await response.text());
}
mutate(buildCCPairInfoUrl(ccPairId));
setPopup({
message: "Connector pruning frequency updated successfully",
type: "success",
});
} catch (error) {
setPopup({
message: "Failed to update connector pruning frequency",
type: "error",
});
}
};
if (isLoading) {
return <ThreeDotsLoader />;
}
@@ -220,35 +114,9 @@ function Main({ ccPairId }: { ccPairId: number }) {
refresh_freq: refreshFreq,
indexing_start: indexingStart,
} = ccPair.connector;
return (
<>
{popup}
{editingRefreshFrequency && (
<EditPropertyModal
propertyTitle="Refresh Frequency"
propertyDetails="How often the connector should refresh (in seconds)"
propertyName="refresh_frequency"
propertyValue={String(refreshFreq)}
validationSchema={RefreshFrequencySchema}
onSubmit={handleRefreshSubmit}
onClose={() => setEditingRefreshFrequency(false)}
/>
)}
{editingPruningFrequency && (
<EditPropertyModal
propertyTitle="Pruning Frequency"
propertyDetails="How often the connector should be pruned (in seconds)"
propertyName="pruning_frequency"
propertyValue={String(pruneFreq)}
validationSchema={PruneFrequencySchema}
onSubmit={handlePruningSubmit}
onClose={() => setEditingPruningFrequency(false)}
/>
)}
<BackButton
behaviorOverride={() => router.push("/admin/indexing/status")}
/>
@@ -257,7 +125,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
<SourceIcon iconSize={32} sourceType={ccPair.connector.source} />
</div>
<div className="ml-1 overflow-hidden text-ellipsis whitespace-nowrap flex-1 mr-4">
<div className="ml-1">
<EditableStringFieldDisplay
value={ccPair.name}
isEditable={ccPair.is_editable_for_current_user}
@@ -345,8 +213,6 @@ function Main({ ccPairId }: { ccPairId: number }) {
pruneFreq={pruneFreq}
indexingStart={indexingStart}
refreshFreq={refreshFreq}
onRefreshEdit={handleRefreshEdit}
onPruningEdit={handlePruningEdit}
/>
)}

View File

@@ -49,8 +49,6 @@ import { useRouter } from "next/navigation";
import CardSection from "@/components/admin/CardSection";
import { prepareOAuthAuthorizationRequest } from "@/lib/oauth_utils";
import { EE_ENABLED, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import TemporaryLoadingModal from "@/components/TemporaryLoadingModal";
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
export interface AdvancedConfig {
refreshFreq: number;
pruneFreq: number;
@@ -162,7 +160,6 @@ export default function AddConnector({
// Form context and popup management
const { setFormStep, setAllowCreate, formStep } = useFormContext();
const { popup, setPopup } = usePopup();
const [uploading, setUploading] = useState(false);
// Hooks for Google Drive and Gmail credentials
const { liveGDriveCredential } = useGoogleDriveCredentials(connector);
@@ -340,24 +337,16 @@ export default function AddConnector({
}
// File-specific handling
if (connector == "file") {
setUploading(true);
try {
const response = await submitFiles(
selectedFiles,
setPopup,
name,
access_type,
groups
);
if (response) {
onSuccess();
}
} catch (error) {
setPopup({ message: "Error uploading files", type: "error" });
} finally {
setUploading(false);
const response = await submitFiles(
selectedFiles,
setPopup,
name,
access_type,
groups
);
if (response) {
onSuccess();
}
return;
}
@@ -419,9 +408,9 @@ export default function AddConnector({
<div className="mx-auto mb-8 w-full">
{popup}
{uploading && (
<TemporaryLoadingModal content="Uploading files..." />
)}
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle
includeDivider={false}
@@ -453,19 +442,11 @@ export default function AddConnector({
{/* Button to pop up a form to manually enter credentials */}
<button
className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
onClick={async () => {
const redirectUrl =
await getConnectorOauthRedirectUrl(connector);
// if redirect is supported, just use it
if (redirectUrl) {
window.location.href = redirectUrl;
} else {
setCreateConnectorToggle(
(createConnectorToggle) =>
!createConnectorToggle
);
}
}}
onClick={() =>
setCreateConnectorToggle(
(createConnectorToggle) => !createConnectorToggle
)
}
>
Create New
</button>

View File

@@ -104,9 +104,7 @@ const GDriveMain = ({}: {}) => {
const googleDriveServiceAccountCredential:
| Credential<GoogleDriveServiceAccountCredentialJson>
| undefined = credentialsData.find(
(credential) =>
credential.credential_json?.google_service_account_key &&
credential.source === "google_drive"
(credential) => credential.credential_json?.google_service_account_key
);
const googleDriveConnectorIndexingStatuses: ConnectorIndexingStatus<

View File

@@ -135,7 +135,7 @@ export const DocumentFeedbackTable = ({
/>
</TableCell>
<TableCell>
<div className="relative">
<div className="ml-auto flex w-16">
<div
key={document.document_id}
className="h-10 ml-auto mr-8"

View File

@@ -0,0 +1,46 @@
import useSWR from "swr";
import { InputPrompt } from "./interfaces";
const fetcher = (url: string) => fetch(url).then((res) => res.json());
export const useAdminInputPrompts = () => {
const { data, error, mutate } = useSWR<InputPrompt[]>(
`/api/admin/input_prompt`,
fetcher
);
return {
data,
error,
isLoading: !error && !data,
refreshInputPrompts: mutate,
};
};
export const useInputPrompts = (includePublic: boolean = false) => {
const { data, error, mutate } = useSWR<InputPrompt[]>(
`/api/input_prompt${includePublic ? "?include_public=true" : ""}`,
fetcher
);
return {
data,
error,
isLoading: !error && !data,
refreshInputPrompts: mutate,
};
};
export const useInputPrompt = (id: number) => {
const { data, error, mutate } = useSWR<InputPrompt>(
`/api/input_prompt/${id}`,
fetcher
);
return {
data,
error,
isLoading: !error && !data,
refreshInputPrompt: mutate,
};
};

View File

@@ -0,0 +1,31 @@
export interface InputPrompt {
id: number;
prompt: string;
content: string;
active: boolean;
is_public: string;
}
export interface EditPromptModalProps {
onClose: () => void;
promptId: number;
editInputPrompt: (
promptId: number,
values: CreateInputPromptRequest
) => Promise<void>;
}
export interface CreateInputPromptRequest {
prompt: string;
content: string;
}
export interface AddPromptModalProps {
onClose: () => void;
onSubmit: (promptData: CreateInputPromptRequest) => void;
}
export interface PromptData {
id: number;
prompt: string;
content: string;
}

View File

@@ -0,0 +1,69 @@
import React from "react";
import { Formik, Form } from "formik";
import * as Yup from "yup";
import { Button } from "@/components/ui/button";
import { BookstackIcon } from "@/components/icons/icons";
import { AddPromptModalProps } from "../interfaces";
import { TextFormField } from "@/components/admin/connectors/Field";
import { Modal } from "@/components/Modal";
const AddPromptSchema = Yup.object().shape({
title: Yup.string().required("Title is required"),
prompt: Yup.string().required("Prompt is required"),
});
const AddPromptModal = ({ onClose, onSubmit }: AddPromptModalProps) => {
return (
<Modal onOutsideClick={onClose} width="w-full max-w-3xl">
<Formik
initialValues={{
title: "",
prompt: "",
}}
validationSchema={AddPromptSchema}
onSubmit={(values, { setSubmitting }) => {
onSubmit({
prompt: values.title,
content: values.prompt,
});
setSubmitting(false);
onClose();
}}
>
{({ isSubmitting, setFieldValue }) => (
<Form>
<h2 className="w-full text-2xl gap-x-2 text-emphasis font-bold mb-3 flex items-center">
<BookstackIcon size={20} />
Add prompt
</h2>
<TextFormField
label="Title"
name="title"
placeholder="Title (e.g. 'Reword')"
/>
<TextFormField
isTextArea
label="Prompt"
name="prompt"
placeholder="Enter a prompt (e.g. 'help me rewrite the following politely and concisely for professional communication')"
/>
<Button
type="submit"
className="w-full"
disabled={isSubmitting}
variant="submit"
>
Add prompt
</Button>
</Form>
)}
</Formik>
</Modal>
);
};
export default AddPromptModal;

View File

@@ -0,0 +1,138 @@
import React from "react";
import { Formik, Form, Field, ErrorMessage } from "formik";
import * as Yup from "yup";
import { Modal } from "@/components/Modal";
import { Textarea } from "@/components/ui/textarea";
import { Button } from "@/components/ui/button";
import { useInputPrompt } from "../hooks";
import { EditPromptModalProps } from "../interfaces";
const EditPromptSchema = Yup.object().shape({
prompt: Yup.string().required("Title is required"),
content: Yup.string().required("Content is required"),
active: Yup.boolean(),
});
const EditPromptModal = ({
onClose,
promptId,
editInputPrompt,
}: EditPromptModalProps) => {
const {
data: promptData,
error,
refreshInputPrompt,
} = useInputPrompt(promptId);
if (error)
return (
<Modal onOutsideClick={onClose} width="max-w-xl">
<p>Failed to load prompt data</p>
</Modal>
);
if (!promptData)
return (
<Modal onOutsideClick={onClose} width="w-full max-w-xl">
<p>Loading...</p>
</Modal>
);
return (
<Modal onOutsideClick={onClose} width="w-full max-w-xl">
<Formik
initialValues={{
prompt: promptData.prompt,
content: promptData.content,
active: promptData.active,
}}
validationSchema={EditPromptSchema}
onSubmit={(values) => {
editInputPrompt(promptId, values);
refreshInputPrompt();
}}
>
{({ isSubmitting, values }) => (
<Form className="items-stretch">
<h2 className="text-2xl text-emphasis font-bold mb-3 flex items-center">
<svg
className="w-6 h-6 mr-2"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M3 17.25V21h3.75L17.81 9.94l-3.75-3.75L3 17.25zM20.71 7.04c.39-.39.39-1.02 0-1.41l-2.34-2.34c-.39-.39-1.02-.39-1.41 0l-1.83 1.83 3.75 3.75 1.83-1.83z" />
</svg>
Edit prompt
</h2>
<div className="space-y-4">
<div>
<label
htmlFor="prompt"
className="block text-sm font-medium mb-1"
>
Title
</label>
<Field
as={Textarea}
id="prompt"
name="prompt"
placeholder="Title (e.g. 'Draft email')"
/>
<ErrorMessage
name="prompt"
component="div"
className="text-red-500 text-sm mt-1"
/>
</div>
<div>
<label
htmlFor="content"
className="block text-sm font-medium mb-1"
>
Content
</label>
<Field
as={Textarea}
id="content"
name="content"
placeholder="Enter prompt content (e.g. 'Write a professional-sounding email about the following content')"
rows={4}
/>
<ErrorMessage
name="content"
component="div"
className="text-red-500 text-sm mt-1"
/>
</div>
<div>
<label className="flex items-center">
<Field type="checkbox" name="active" className="mr-2" />
Active prompt
</label>
</div>
</div>
<div className="mt-6">
<Button
type="submit"
disabled={
isSubmitting ||
(values.prompt === promptData.prompt &&
values.content === promptData.content &&
values.active === promptData.active)
}
>
{isSubmitting ? "Updating..." : "Update prompt"}
</Button>
</div>
</Form>
)}
</Formik>
</Modal>
);
};
export default EditPromptModal;

View File

@@ -0,0 +1,32 @@
"use client";
import { AdminPageTitle } from "@/components/admin/Title";
import { ClosedBookIcon } from "@/components/icons/icons";
import { useAdminInputPrompts } from "./hooks";
import { PromptSection } from "./promptSection";
const Page = () => {
const {
data: promptLibrary,
error: promptLibraryError,
isLoading: promptLibraryIsLoading,
refreshInputPrompts: refreshPrompts,
} = useAdminInputPrompts();
return (
<div className="container mx-auto">
<AdminPageTitle
icon={<ClosedBookIcon size={32} />}
title="Prompt Library"
/>
<PromptSection
promptLibrary={promptLibrary || []}
isLoading={promptLibraryIsLoading}
error={promptLibraryError}
refreshPrompts={refreshPrompts}
isPublic={true}
/>
</div>
);
};
export default Page;

View File

@@ -0,0 +1,249 @@
"use client";
import { EditIcon, TrashIcon } from "@/components/icons/icons";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { MagnifyingGlass } from "@phosphor-icons/react";
import { useState } from "react";
import {
Table,
TableHead,
TableRow,
TableBody,
TableCell,
} from "@/components/ui/table";
import { FilterDropdown } from "@/components/search/filtering/FilterDropdown";
import { FiTag } from "react-icons/fi";
import { PageSelector } from "@/components/PageSelector";
import { InputPrompt } from "./interfaces";
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
import { TableHeader } from "@/components/ui/table";
const CategoryBubble = ({
name,
onDelete,
}: {
name: string;
onDelete?: () => void;
}) => (
<span
className={`
inline-block
px-2
py-1
mr-1
mb-1
text-xs
font-semibold
text-emphasis
bg-hover
rounded-full
items-center
w-fit
${onDelete ? "cursor-pointer" : ""}
`}
onClick={onDelete}
>
{name}
{onDelete && (
<button
className="ml-1 text-subtle hover:text-emphasis"
aria-label="Remove category"
>
&times;
</button>
)}
</span>
);
const NUM_RESULTS_PER_PAGE = 10;
export const PromptLibraryTable = ({
promptLibrary,
refresh,
setPopup,
handleEdit,
isPublic,
}: {
promptLibrary: InputPrompt[];
refresh: () => void;
setPopup: (popup: PopupSpec | null) => void;
handleEdit: (promptId: number) => void;
isPublic: boolean;
}) => {
const [query, setQuery] = useState("");
const [currentPage, setCurrentPage] = useState(1);
const [selectedStatus, setSelectedStatus] = useState<string[]>([]);
const columns = [
{ name: "Prompt", key: "prompt" },
{ name: "Content", key: "content" },
{ name: "Status", key: "status" },
{ name: "", key: "edit" },
{ name: "", key: "delete" },
];
const filteredPromptLibrary = promptLibrary.filter((item) => {
const cleanedQuery = query.toLowerCase();
const searchMatch =
item.prompt.toLowerCase().includes(cleanedQuery) ||
item.content.toLowerCase().includes(cleanedQuery);
const statusMatch =
selectedStatus.length === 0 ||
(selectedStatus.includes("Active") && item.active) ||
(selectedStatus.includes("Inactive") && !item.active);
return searchMatch && statusMatch;
});
const totalPages = Math.ceil(
filteredPromptLibrary.length / NUM_RESULTS_PER_PAGE
);
const startIndex = (currentPage - 1) * NUM_RESULTS_PER_PAGE;
const endIndex = startIndex + NUM_RESULTS_PER_PAGE;
const paginatedPromptLibrary = filteredPromptLibrary.slice(
startIndex,
endIndex
);
const handlePageChange = (page: number) => {
setCurrentPage(page);
};
const handleDelete = async (id: number) => {
const response = await fetch(
`/api${isPublic ? "/admin" : ""}/input_prompt/${id}`,
{
method: "DELETE",
}
);
if (!response.ok) {
setPopup({ message: "Failed to delete input prompt", type: "error" });
}
refresh();
setConfirmDeletionId(null);
};
const handleStatusSelect = (status: string) => {
setSelectedStatus((prev) => {
if (prev.includes(status)) {
return prev.filter((s) => s !== status);
}
return [...prev, status];
});
};
const [confirmDeletionId, setConfirmDeletionId] = useState<number | null>(
null
);
return (
<div className="justify-center py-2">
{confirmDeletionId != null && (
<DeleteEntityModal
onClose={() => setConfirmDeletionId(null)}
onSubmit={() => handleDelete(confirmDeletionId)}
entityType="prompt"
entityName={
paginatedPromptLibrary.find(
(prompt) => prompt.id === confirmDeletionId
)?.prompt ?? ""
}
/>
)}
<div className="flex items-center w-full border-2 border-border rounded-lg px-4 py-2 focus-within:border-accent">
<MagnifyingGlass />
<input
className="flex-grow ml-2 bg-transparent outline-none placeholder-subtle"
placeholder="Find prompts..."
value={query}
onChange={(event) => {
setQuery(event.target.value);
setCurrentPage(1);
}}
/>
</div>
<div className="my-4 border-b border-border">
<FilterDropdown
options={[
{ key: "Active", display: "Active" },
{ key: "Inactive", display: "Inactive" },
]}
selected={selectedStatus}
handleSelect={(option) => handleStatusSelect(option.key)}
icon={<FiTag size={16} />}
defaultDisplay="All Statuses"
/>
<div className="flex flex-col items-stretch w-full flex-wrap pb-4 mt-3">
{selectedStatus.map((status) => (
<CategoryBubble
key={status}
name={status}
onDelete={() => handleStatusSelect(status)}
/>
))}
</div>
</div>
<div className="mx-auto overflow-x-auto">
<Table>
<TableHeader>
<TableRow>
{columns.map((column) => (
<TableHead key={column.key}>{column.name}</TableHead>
))}
</TableRow>
</TableHeader>
<TableBody>
{paginatedPromptLibrary.length > 0 ? (
paginatedPromptLibrary
.filter((prompt) => !(!isPublic && prompt.is_public))
.map((item) => (
<TableRow key={item.id}>
<TableCell>{item.prompt}</TableCell>
<TableCell
className="
max-w-xs
overflow-hidden
text-ellipsis
break-words
"
>
{item.content}
</TableCell>
<TableCell>{item.active ? "Active" : "Inactive"}</TableCell>
<TableCell>
<button
className="cursor-pointer"
onClick={() => setConfirmDeletionId(item.id)}
>
<TrashIcon size={20} />
</button>
</TableCell>
<TableCell>
<button onClick={() => handleEdit(item.id)}>
<EditIcon size={12} />
</button>
</TableCell>
</TableRow>
))
) : (
<TableRow>
<TableCell colSpan={6}>No matching prompts found...</TableCell>
</TableRow>
)}
</TableBody>
</Table>
{paginatedPromptLibrary.length > 0 && (
<div className="mt-4 flex justify-center">
<PageSelector
currentPage={currentPage}
totalPages={totalPages}
onPageChange={handlePageChange}
shouldScroll={true}
/>
</div>
)}
</div>
</div>
);
};

View File

@@ -0,0 +1,150 @@
"use client";
import { usePopup } from "@/components/admin/connectors/Popup";
import { ThreeDotsLoader } from "@/components/Loading";
import { ErrorCallout } from "@/components/ErrorCallout";
import { Button } from "@/components/ui/button";
import { Separator } from "@/components/ui/separator";
import Text from "@/components/ui/text";
import { useState } from "react";
import AddPromptModal from "./modals/AddPromptModal";
import EditPromptModal from "./modals/EditPromptModal";
import { PromptLibraryTable } from "./promptLibrary";
import { CreateInputPromptRequest, InputPrompt } from "./interfaces";
export const PromptSection = ({
promptLibrary,
isLoading,
error,
refreshPrompts,
centering = false,
isPublic,
}: {
promptLibrary: InputPrompt[];
isLoading: boolean;
error: any;
refreshPrompts: () => void;
centering?: boolean;
isPublic: boolean;
}) => {
const { popup, setPopup } = usePopup();
const [newPrompt, setNewPrompt] = useState(false);
const [newPromptId, setNewPromptId] = useState<number | null>(null);
const createInputPrompt = async (
promptData: CreateInputPromptRequest
): Promise<InputPrompt> => {
const response = await fetch("/api/input_prompt", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ ...promptData, is_public: isPublic }),
});
if (!response.ok) {
setPopup({ message: "Failed to create input prompt", type: "error" });
}
refreshPrompts();
return response.json();
};
const editInputPrompt = async (
promptId: number,
values: CreateInputPromptRequest
) => {
try {
const response = await fetch(`/api/input_prompt/${promptId}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(values),
});
if (!response.ok) {
setPopup({ message: "Failed to update prompt!", type: "error" });
}
setNewPromptId(null);
refreshPrompts();
} catch (err) {
setPopup({ message: `Failed to update prompt: ${err}`, type: "error" });
}
};
if (isLoading) {
return <ThreeDotsLoader />;
}
if (error || !promptLibrary) {
return (
<ErrorCallout
errorTitle="Error loading standard answers"
errorMsg={error?.info?.message || error?.message?.info?.detail}
/>
);
}
const handleEdit = (promptId: number) => {
setNewPromptId(promptId);
};
return (
<div
className={`w-full ${
centering ? "flex-col flex justify-center" : ""
} mb-8`}
>
{popup}
{newPrompt && (
<AddPromptModal
onSubmit={createInputPrompt}
onClose={() => setNewPrompt(false)}
/>
)}
{newPromptId && (
<EditPromptModal
promptId={newPromptId}
editInputPrompt={editInputPrompt}
onClose={() => setNewPromptId(null)}
/>
)}
<div className={centering ? "max-w-sm mx-auto" : ""}>
<Text className="mb-2 my-auto">
Create prompts that can be accessed with the <i>`/`</i> shortcut in
Danswer Chat.{" "}
{isPublic
? "Prompts created here will be accessible to all users."
: "Prompts created here will be available only to you."}
</Text>
</div>
<div className="mb-2"></div>
<Button
onClick={() => setNewPrompt(true)}
className={centering ? "mx-auto" : ""}
variant="navigate"
size="sm"
>
New Prompt
</Button>
<Separator />
<div>
<PromptLibraryTable
isPublic={isPublic}
promptLibrary={promptLibrary}
setPopup={setPopup}
refresh={refreshPrompts}
handleEdit={handleEdit}
/>
</div>
</div>
);
};

View File

@@ -1,13 +1,13 @@
"use client";
import { useEffect, useState } from "react";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Button } from "@/components/ui/button";
import InvitedUserTable from "@/components/admin/users/InvitedUserTable";
import SignedUpUserTable from "@/components/admin/users/SignedUpUserTable";
import { SearchBar } from "@/components/search/SearchBar";
import { useState } from "react";
import { FiPlusSquare } from "react-icons/fi";
import { Modal } from "@/components/Modal";
import { Button } from "@/components/ui/button";
import Text from "@/components/ui/text";
import { LoadingAnimation } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import { usePopup, PopupSpec } from "@/components/admin/connectors/Popup";
@@ -15,10 +15,42 @@ import { UsersIcon } from "@/components/icons/icons";
import { errorHandlingFetcher } from "@/lib/fetcher";
import useSWR, { mutate } from "swr";
import { ErrorCallout } from "@/components/ErrorCallout";
import { HidableSection } from "@/app/admin/assistants/HidableSection";
import BulkAdd from "@/components/admin/users/BulkAdd";
import { UsersResponse } from "@/lib/users/interfaces";
import SlackUserTable from "@/components/admin/users/SlackUserTable";
import Text from "@/components/ui/text";
const ValidDomainsDisplay = ({ validDomains }: { validDomains: string[] }) => {
if (!validDomains.length) {
return (
<div className="text-sm">
No invited users. Anyone can sign up with a valid email address. To
restrict access you can:
<div className="flex flex-wrap ml-2 mt-1">
(1) Invite users above. Once a user has been invited, only emails that
have explicitly been invited will be able to sign-up.
</div>
<div className="mt-1 ml-2">
(2) Set the{" "}
<b className="font-mono w-fit h-fit">VALID_EMAIL_DOMAINS</b>{" "}
environment variable to a comma separated list of email domains. This
will restrict access to users with email addresses from these domains.
</div>
</div>
);
}
return (
<div className="text-sm">
No invited users. Anyone with an email address with any of the following
domains can sign up: <i>{validDomains.join(", ")}</i>.
<div className="mt-2">
To further restrict access you can invite users above. Once a user has
been invited, only emails that have explicitly been invited will be able
to sign-up.
</div>
</div>
);
};
const UsersTables = ({
q,
@@ -29,48 +61,23 @@ const UsersTables = ({
}) => {
const [invitedPage, setInvitedPage] = useState(1);
const [acceptedPage, setAcceptedPage] = useState(1);
const [slackUsersPage, setSlackUsersPage] = useState(1);
const [usersData, setUsersData] = useState<UsersResponse | undefined>(
undefined
);
const [domainsData, setDomainsData] = useState<string[] | undefined>(
undefined
);
const { data, error, mutate } = useSWR<UsersResponse>(
`/api/manage/users?q=${encodeURIComponent(q)}&accepted_page=${
const { data, isLoading, mutate, error } = useSWR<UsersResponse>(
`/api/manage/users?q=${encodeURI(q)}&accepted_page=${
acceptedPage - 1
}&invited_page=${invitedPage - 1}&slack_users_page=${slackUsersPage - 1}`,
}&invited_page=${invitedPage - 1}`,
errorHandlingFetcher
);
const {
data: validDomains,
isLoading: isLoadingDomains,
error: domainsError,
} = useSWR<string[]>("/api/manage/admin/valid-domains", errorHandlingFetcher);
const { data: validDomains, error: domainsError } = useSWR<string[]>(
"/api/manage/admin/valid-domains",
errorHandlingFetcher
);
useEffect(() => {
if (data) {
setUsersData(data);
}
}, [data]);
useEffect(() => {
if (validDomains) {
setDomainsData(validDomains);
}
}, [validDomains]);
const activeData = data ?? usersData;
const activeDomains = validDomains ?? domainsData;
// Show loading animation only during the initial data fetch
if (!activeData || !activeDomains) {
if (isLoading || isLoadingDomains) {
return <LoadingAnimation text="Loading" />;
}
if (error) {
if (error || !data) {
return (
<ErrorCallout
errorTitle="Error loading users"
@@ -79,7 +86,7 @@ const UsersTables = ({
);
}
if (domainsError) {
if (domainsError || !validDomains) {
return (
<ErrorCallout
errorTitle="Error loading valid domains"
@@ -88,94 +95,45 @@ const UsersTables = ({
);
}
const {
accepted,
invited,
accepted_pages,
invited_pages,
slack_users,
slack_users_pages,
} = activeData;
const { accepted, invited, accepted_pages, invited_pages } = data;
// remove users that are already accepted
const finalInvited = invited.filter(
(user) => !accepted.some((u) => u.email === user.email)
(user) => !accepted.map((u) => u.email).includes(user.email)
);
return (
<Tabs defaultValue="invited">
<TabsList>
<TabsTrigger value="invited">Invited Users</TabsTrigger>
<TabsTrigger value="current">Current Users</TabsTrigger>
<TabsTrigger value="danswerbot">DanswerBot Users</TabsTrigger>
</TabsList>
<TabsContent value="invited">
<Card>
<CardHeader>
<CardTitle>Invited Users</CardTitle>
</CardHeader>
<CardContent>
{finalInvited.length > 0 ? (
<InvitedUserTable
users={finalInvited}
setPopup={setPopup}
currentPage={invitedPage}
onPageChange={setInvitedPage}
totalPages={invited_pages}
mutate={mutate}
/>
) : (
<p>Users that have been invited will show up here</p>
)}
</CardContent>
</Card>
</TabsContent>
<TabsContent value="current">
<Card>
<CardHeader>
<CardTitle>Current Users</CardTitle>
</CardHeader>
<CardContent>
{accepted.length > 0 ? (
<SignedUpUserTable
users={accepted}
setPopup={setPopup}
currentPage={acceptedPage}
onPageChange={setAcceptedPage}
totalPages={accepted_pages}
mutate={mutate}
/>
) : (
<p>Users that have an account will show up here</p>
)}
</CardContent>
</Card>
</TabsContent>
<TabsContent value="danswerbot">
<Card>
<CardHeader>
<CardTitle>DanswerBot Users</CardTitle>
</CardHeader>
<CardContent>
{slack_users.length > 0 ? (
<SlackUserTable
setPopup={setPopup}
currentPage={slackUsersPage}
onPageChange={setSlackUsersPage}
totalPages={slack_users_pages}
invitedUsers={finalInvited}
slackusers={slack_users}
mutate={mutate}
/>
) : (
<p>Slack-only users will show up here</p>
)}
</CardContent>
</Card>
</TabsContent>
</Tabs>
<>
<HidableSection sectionTitle="Invited Users">
{invited.length > 0 ? (
finalInvited.length > 0 ? (
<InvitedUserTable
users={finalInvited}
setPopup={setPopup}
currentPage={invitedPage}
onPageChange={setInvitedPage}
totalPages={invited_pages}
mutate={mutate}
/>
) : (
<div className="text-sm">
To invite additional teammates, use the <b>Invite Users</b> button
above!
</div>
)
) : (
<ValidDomainsDisplay validDomains={validDomains} />
)}
</HidableSection>
<SignedUpUserTable
users={accepted}
setPopup={setPopup}
currentPage={acceptedPage}
onPageChange={setAcceptedPage}
totalPages={accepted_pages}
mutate={mutate}
/>
</>
);
};
@@ -257,7 +215,6 @@ const Page = () => {
return (
<div className="mx-auto container">
<AdminPageTitle title="Manage Users" icon={<UsersIcon size={32} />} />
<SearchableTables />
</div>
);

View File

@@ -0,0 +1,51 @@
"use client";
import SidebarWrapper from "../SidebarWrapper";
import { ChatSession } from "@/app/chat/interfaces";
import { Folder } from "@/app/chat/folders/interfaces";
import { User } from "@/lib/types";
import { AssistantsPageTitle } from "../AssistantsPageTitle";
import { useInputPrompts } from "@/app/admin/prompt-library/hooks";
import { PromptSection } from "@/app/admin/prompt-library/promptSection";
export default function WrappedPrompts({
chatSessions,
initiallyToggled,
folders,
openedFolders,
}: {
chatSessions: ChatSession[];
folders: Folder[];
initiallyToggled: boolean;
openedFolders?: { [key: number]: boolean };
}) {
const {
data: promptLibrary,
error: promptLibraryError,
isLoading: promptLibraryIsLoading,
refreshInputPrompts: refreshPrompts,
} = useInputPrompts(false);
return (
<SidebarWrapper
size="lg"
page="chat"
initiallyToggled={initiallyToggled}
chatSessions={chatSessions}
folders={folders}
openedFolders={openedFolders}
>
<div className="mx-auto w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar">
<AssistantsPageTitle>Prompt Gallery</AssistantsPageTitle>
<PromptSection
promptLibrary={promptLibrary || []}
isLoading={promptLibraryIsLoading}
error={promptLibraryError}
refreshPrompts={refreshPrompts}
isPublic={false}
centering
/>
</div>
</SidebarWrapper>
);
}

View File

@@ -60,11 +60,7 @@ export function ChatBanner() {
<div className={`flex justify-center w-full overflow-hidden pr-8`}>
<div
ref={contentRef}
className={`overflow-hidden ${
settings.enterpriseSettings.two_lines_for_chat_header
? "line-clamp-2"
: "line-clamp-1"
} text-center max-w-full`}
className={`overflow-hidden ${settings.enterpriseSettings.two_lines_for_chat_header ? "line-clamp-2" : "line-clamp-1"} text-center max-w-full`}
>
<MinimalMarkdown
className="prose text-sm max-w-full"
@@ -75,11 +71,7 @@ export function ChatBanner() {
<div className="absolute top-0 left-0 invisible flex justify-center max-w-full">
<div
ref={fullContentRef}
className={`overflow-hidden invisible ${
settings.enterpriseSettings.two_lines_for_chat_header
? "line-clamp-2"
: "line-clamp-1"
} text-center max-w-full`}
className={`overflow-hidden invisible ${settings.enterpriseSettings.two_lines_for_chat_header ? "line-clamp-2" : "line-clamp-1"} text-center max-w-full`}
>
<MinimalMarkdown
className="prose text-sm max-w-full"

View File

@@ -135,6 +135,7 @@ export function ChatPage({
llmProviders,
folders,
openedFolders,
userInputPrompts,
defaultAssistantId,
shouldShowWelcomeModal,
refreshChatSessions,
@@ -470,14 +471,13 @@ export function ChatPage({
loadedSessionId != null) &&
!currentChatAnswering()
) {
updateCompleteMessageDetail(chatSession.chat_session_id, newMessageMap);
const latestMessageId =
newMessageHistory[newMessageHistory.length - 1]?.messageId;
setSelectedMessageForDocDisplay(
latestMessageId !== undefined ? latestMessageId : null
);
updateCompleteMessageDetail(chatSession.chat_session_id, newMessageMap);
}
setChatSessionSharedStatus(chatSession.shared_status);
@@ -976,7 +976,6 @@ export function ChatPage({
useEffect(() => {
if (
!personaIncludesRetrieval &&
(!selectedDocuments || selectedDocuments.length === 0) &&
documentSidebarToggled &&
!filtersToggled
@@ -1080,17 +1079,10 @@ export function ChatPage({
updateCanContinue(false, frozenSessionId);
if (currentChatState() != "input") {
if (currentChatState() == "uploading") {
setPopup({
message: "Please wait for the content to upload",
type: "error",
});
} else {
setPopup({
message: "Please wait for the response to complete",
type: "error",
});
}
setPopup({
message: "Please wait for the response to complete",
type: "error",
});
return;
}
@@ -1565,7 +1557,7 @@ export function ChatPage({
}
};
const handleImageUpload = async (acceptedFiles: File[]) => {
const handleImageUpload = (acceptedFiles: File[]) => {
const [_, llmModel] = getFinalLLM(
llmProviders,
liveAssistant,
@@ -1605,9 +1597,8 @@ export function ChatPage({
(file) => !tempFileDescriptors.some((newFile) => newFile.id === file.id)
);
};
updateChatState("uploading", currentSessionId());
await uploadFilesForChat(acceptedFiles).then(([files, error]) => {
uploadFilesForChat(acceptedFiles).then(([files, error]) => {
if (error) {
setCurrentMessageFiles((prev) => removeTempFiles(prev));
setPopup({
@@ -1618,7 +1609,6 @@ export function ChatPage({
setCurrentMessageFiles((prev) => [...removeTempFiles(prev), ...files]);
}
});
updateChatState("input", currentSessionId());
};
const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open
@@ -2241,7 +2231,7 @@ export function ChatPage({
ref={scrollableDivRef}
>
{liveAssistant && onAssistantChange && (
<div className="z-20 fixed top-0 pointer-events-none left-0 w-full flex justify-center overflow-visible">
<div className="z-20 fixed top-4 pointer-events-none left-0 w-full flex justify-center overflow-visible">
{!settings?.isMobile && (
<div
style={{ transition: "width 0.30s ease-out" }}
@@ -2754,6 +2744,7 @@ export function ChatPage({
chatState={currentSessionChatState}
stopGenerating={stopGenerating}
openModelSettings={() => setSettingsToggled(true)}
inputPrompts={userInputPrompts}
showDocs={() => setDocumentSelection(true)}
selectedDocuments={selectedDocuments}
// assistant stuff

View File

@@ -85,7 +85,7 @@ export const ChatFilters = forwardRef<HTMLDivElement, ChatFiltersProps>(
return (
<div
id="danswer-chat-sidebar"
className={`relative max-w-full ${
className={`relative py-2 max-w-full ${
!modal ? "border-l h-full border-sidebar-border" : ""
}`}
onClick={(e) => {

View File

@@ -2,7 +2,7 @@ import React, { useContext, useEffect, useRef, useState } from "react";
import { FiPlusCircle, FiPlus, FiInfo, FiX, FiSearch } from "react-icons/fi";
import { ChatInputOption } from "./ChatInputOption";
import { Persona } from "@/app/admin/assistants/interfaces";
import { InputPrompt } from "@/app/admin/prompt-library/interfaces";
import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
import { SelectedFilterDisplay } from "./SelectedFilterDisplay";
import { useChatContext } from "@/components/context/ChatContext";
@@ -58,6 +58,7 @@ interface ChatInputBarProps {
llmOverrideManager: LlmOverrideManager;
chatState: ChatState;
alternativeAssistant: Persona | null;
inputPrompts: InputPrompt[];
// assistants
selectedAssistant: Persona;
setSelectedAssistant: (assistant: Persona) => void;
@@ -97,6 +98,7 @@ export function ChatInputBar({
textAreaRef,
alternativeAssistant,
chatSessionId,
inputPrompts,
toggleFilters,
}: ChatInputBarProps) {
useEffect(() => {
@@ -135,6 +137,7 @@ export function ChatInputBar({
const suggestionsRef = useRef<HTMLDivElement | null>(null);
const [showSuggestions, setShowSuggestions] = useState(false);
const [showPrompts, setShowPrompts] = useState(false);
const interactionsRef = useRef<HTMLDivElement | null>(null);
@@ -143,6 +146,19 @@ export function ChatInputBar({
setTabbingIconIndex(0);
};
const hidePrompts = () => {
setTimeout(() => {
setShowPrompts(false);
}, 50);
setTabbingIconIndex(0);
};
const updateInputPrompt = (prompt: InputPrompt) => {
hidePrompts();
setMessage(`${prompt.content}`);
};
useEffect(() => {
const handleClickOutside = (event: MouseEvent) => {
if (
@@ -152,6 +168,7 @@ export function ChatInputBar({
!interactionsRef.current.contains(event.target as Node))
) {
hideSuggestions();
hidePrompts();
}
};
document.addEventListener("mousedown", handleClickOutside);
@@ -181,10 +198,24 @@ export function ChatInputBar({
}
};
const handlePromptInput = (text: string) => {
if (!text.startsWith("/")) {
hidePrompts();
} else {
const promptMatch = text.match(/(?:\s|^)\/(\w*)$/);
if (promptMatch) {
setShowPrompts(true);
} else {
hidePrompts();
}
}
};
const handleInputChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
const text = event.target.value;
setMessage(text);
handleAssistantInput(text);
handlePromptInput(text);
};
const assistantTagOptions = assistantOptions.filter((assistant) =>
@@ -196,26 +227,49 @@ export function ChatInputBar({
)
);
const filteredPrompts = inputPrompts.filter(
(prompt) =>
prompt.active &&
prompt.prompt.toLowerCase().startsWith(
message
.slice(message.lastIndexOf("/") + 1)
.split(/\s/)[0]
.toLowerCase()
)
);
const [tabbingIconIndex, setTabbingIconIndex] = useState(0);
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (
showSuggestions &&
assistantTagOptions.length > 0 &&
((showSuggestions && assistantTagOptions.length > 0) || showPrompts) &&
(e.key === "Tab" || e.key == "Enter")
) {
e.preventDefault();
if (tabbingIconIndex == assistantTagOptions.length && showSuggestions) {
window.open("/assistants/new", "_self");
if (
(tabbingIconIndex == assistantTagOptions.length && showSuggestions) ||
(tabbingIconIndex == filteredPrompts.length && showPrompts)
) {
if (showPrompts) {
window.open("/prompts", "_self");
} else {
window.open("/assistants/new", "_self");
}
} else {
const option =
assistantTagOptions[tabbingIconIndex >= 0 ? tabbingIconIndex : 0];
if (showPrompts) {
const uppity =
filteredPrompts[tabbingIconIndex >= 0 ? tabbingIconIndex : 0];
updateInputPrompt(uppity);
} else {
const option =
assistantTagOptions[tabbingIconIndex >= 0 ? tabbingIconIndex : 0];
updatedTaggedAssistant(option);
updatedTaggedAssistant(option);
}
}
}
if (!showSuggestions) {
if (!showPrompts && !showSuggestions) {
return;
}
@@ -223,7 +277,10 @@ export function ChatInputBar({
e.preventDefault();
setTabbingIconIndex((tabbingIconIndex) =>
Math.min(tabbingIconIndex + 1, assistantTagOptions.length)
Math.min(
tabbingIconIndex + 1,
showPrompts ? filteredPrompts.length : assistantTagOptions.length
)
);
} else if (e.key === "ArrowUp") {
e.preventDefault();
@@ -284,6 +341,48 @@ export function ChatInputBar({
</div>
)}
{showPrompts && (
<div
ref={suggestionsRef}
className="text-sm absolute inset-x-0 top-0 w-full transform -translate-y-full"
>
<div className="rounded-lg py-1.5 bg-white border border-border-medium overflow-hidden shadow-lg mx-2 px-1.5 mt-2 rounded z-10">
{filteredPrompts.map(
(currentPrompt: InputPrompt, index: number) => (
<button
key={index}
className={`px-2 ${
tabbingIconIndex == index && "bg-hover"
} rounded content-start flex gap-x-1 py-1.5 w-full hover:bg-hover cursor-pointer`}
onClick={() => {
updateInputPrompt(currentPrompt);
}}
>
<p className="font-bold">{currentPrompt.prompt}:</p>
<p className="text-left flex-grow mr-auto line-clamp-1">
{currentPrompt.id == selectedAssistant.id &&
"(default) "}
{currentPrompt.content?.trim()}
</p>
</button>
)
)}
<a
key={filteredPrompts.length}
target="_self"
className={`${
tabbingIconIndex == filteredPrompts.length && "bg-hover"
} px-3 flex gap-x-1 py-2 w-full items-center hover:bg-hover-light cursor-pointer"`}
href="/prompts"
>
<FiPlus size={17} />
<p>Create a new prompt</p>
</a>
</div>
</div>
)}
{/* <div>
<SelectedFilterDisplay filterManager={filterManager} />
</div> */}
@@ -435,6 +534,7 @@ export function ChatInputBar({
onKeyDown={(event) => {
if (
event.key === "Enter" &&
!showPrompts &&
!showSuggestions &&
!event.shiftKey &&
!(event.nativeEvent as any).isComposing

View File

@@ -1,12 +1,10 @@
import { Citation } from "@/components/search/results/Citation";
import { WebResultIcon } from "@/components/WebResultIcon";
import { LoadedDanswerDocument } from "@/lib/search/interfaces";
import { getSourceMetadata, SOURCE_METADATA_MAP } from "@/lib/sources";
import { getSourceMetadata } from "@/lib/sources";
import { ValidSources } from "@/lib/types";
import React, { memo } from "react";
import isEqual from "lodash/isEqual";
import { SlackIcon } from "@/components/icons/icons";
import { SourceIcon } from "@/components/SourceIcon";
export const MemoizedAnchor = memo(
({ docs, updatePresentingDocument, children }: any) => {
@@ -21,9 +19,19 @@ export const MemoizedAnchor = memo(
? new URL(associatedDoc.link).origin + "/favicon.ico"
: "";
const icon = (
<SourceIcon sourceType={associatedDoc?.source_type} iconSize={18} />
);
const getIcon = (sourceType: ValidSources, link: string) => {
return getSourceMetadata(sourceType).icon({ size: 18 });
};
const icon =
associatedDoc?.source_type === "web" ? (
<WebResultIcon url={associatedDoc.link} />
) : (
getIcon(
associatedDoc?.source_type || "web",
associatedDoc?.link || ""
)
);
return (
<MemoizedLink

View File

@@ -35,13 +35,6 @@ export function SetDefaultModelModal({
const container = containerRef.current;
const message = messageRef.current;
const handleEscape = (e: KeyboardEvent) => {
if (e.key === "Escape") {
onClose();
}
};
window.addEventListener("keydown", handleEscape);
if (container && message) {
const checkScrollable = () => {
if (container.scrollHeight > container.clientHeight) {
@@ -52,14 +45,9 @@ export function SetDefaultModelModal({
};
checkScrollable();
window.addEventListener("resize", checkScrollable);
return () => {
window.removeEventListener("resize", checkScrollable);
window.removeEventListener("keydown", handleEscape);
};
return () => window.removeEventListener("resize", checkScrollable);
}
return () => window.removeEventListener("keydown", handleEscape);
}, [onClose]);
}, []);
const defaultModelDestructured = defaultModel
? destructureValue(defaultModel)

View File

@@ -31,6 +31,7 @@ export default async function Page(props: {
openedFolders,
defaultAssistantId,
shouldShowWelcomeModal,
userInputPrompts,
ccPairs,
} = data;
@@ -52,6 +53,7 @@ export default async function Page(props: {
llmProviders,
folders,
openedFolders,
userInputPrompts,
shouldShowWelcomeModal,
defaultAssistantId,
}}

View File

@@ -101,7 +101,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
flex-col relative
h-screen
transition-transform
`}
pt-2`}
>
<LogoType
showArrow={true}
@@ -163,6 +163,15 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
Manage Assistants
</p>
</Link>
<Link
href="/prompts"
className="w-full p-2 bg-white border-border border rounded items-center hover:bg-background-history-sidebar-button-hover cursor-pointer transition-all duration-150 flex gap-x-2"
>
<ClosedBookIcon className="h-4 w-4 my-auto text-text-history-sidebar-button" />
<p className="my-auto flex items-center text-sm ">
Manage Prompts
</p>
</Link>
</div>
)}
<div className="border-b border-divider-history-sidebar-bar pb-4 mx-3" />

View File

@@ -22,7 +22,7 @@ export default function FixedLogo({
<>
<Link
href="/chat"
className="fixed cursor-pointer flex z-40 left-4 top-2 h-8"
className="fixed cursor-pointer flex z-40 left-2.5 top-2"
>
<div className="max-w-[200px] mobile:hidden flex items-center gap-x-1 my-auto">
<div className="flex-none my-auto">
@@ -46,8 +46,8 @@ export default function FixedLogo({
</div>
</div>
</Link>
<div className="mobile:hidden fixed left-4 bottom-4">
<FiSidebar className="text-text-mobile-sidebar" />
<div className="mobile:hidden fixed left-2.5 bottom-4">
{/* <FiSidebar className="text-text-mobile-sidebar" /> */}
</div>
</>
);

View File

@@ -1,10 +1,5 @@
export type FeedbackType = "like" | "dislike";
export type ChatState =
| "input"
| "loading"
| "streaming"
| "toolBuilding"
| "uploading";
export type ChatState = "input" | "loading" | "streaming" | "toolBuilding";
export interface RegenerationState {
regenerating: boolean;
finalMessageIndex: number;

View File

@@ -1,43 +0,0 @@
import { INTERNAL_URL } from "@/lib/constants";
import { NextRequest, NextResponse } from "next/server";
// TODO: deprecate this and just go directly to the backend via /api/...
// For some reason Egnyte doesn't work when using /api, so leaving this as is for now
// If we do try and remove this, make sure we test the Egnyte connector oauth flow
export async function GET(request: NextRequest) {
try {
const backendUrl = new URL(INTERNAL_URL);
// Copy path and query parameters from incoming request
backendUrl.pathname = request.nextUrl.pathname;
backendUrl.search = request.nextUrl.search;
const response = await fetch(backendUrl, {
method: "GET",
headers: request.headers,
body: request.body,
signal: request.signal,
// @ts-ignore
duplex: "half",
});
const responseData = await response.json();
if (responseData.redirect_url) {
return NextResponse.redirect(responseData.redirect_url);
}
return new NextResponse(JSON.stringify(responseData), {
status: response.status,
headers: response.headers,
});
} catch (error: unknown) {
console.error("Proxy error:", error);
return NextResponse.json(
{
message: "Proxy error",
error:
error instanceof Error ? error.message : "An unknown error occurred",
},
{ status: 500 }
);
}
}

View File

@@ -0,0 +1,28 @@
import { fetchChatData } from "@/lib/chat/fetchChatData";
import { unstable_noStore as noStore } from "next/cache";
import { redirect } from "next/navigation";
import WrappedPrompts from "../assistants/mine/WrappedInputPrompts";
export default async function GalleryPage(props: {
searchParams: Promise<{ [key: string]: string }>;
}) {
const searchParams = await props.searchParams;
noStore();
const data = await fetchChatData(searchParams);
if ("redirect" in data) {
redirect(data.redirect);
}
const { chatSessions, folders, openedFolders, toggleSidebar } = data;
return (
<WrappedPrompts
initiallyToggled={toggleSidebar}
chatSessions={chatSessions}
folders={folders}
openedFolders={openedFolders}
/>
);
}

View File

@@ -1,14 +0,0 @@
export default function TemporaryLoadingModal({
content,
}: {
content: string;
}) {
return (
<div className="fixed inset-0 flex items-center justify-center z-50 bg-black bg-opacity-30">
<div className="bg-white rounded-xl p-8 shadow-2xl flex items-center space-x-6">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-neutral-950"></div>
<p className="text-xl font-medium text-gray-800">{content}</p>
</div>
</div>
);
}

View File

@@ -58,7 +58,7 @@ export function ClientLayout({
return (
<div className="h-screen overflow-y-hidden">
<div className="flex h-full">
<div className="flex-none text-text-settings-sidebar bg-background-sidebar w-[250px] overflow-x-hidden z-20 pt-2 pb-8 h-full border-r border-border miniscroll overflow-auto">
<div className="flex-none text-text-settings-sidebar bg-background-sidebar w-[250px] z-20 pt-4 pb-8 h-full border-r border-border miniscroll overflow-auto">
<AdminSidebar
collections={[
{
@@ -169,6 +169,18 @@ export function ClientLayout({
),
link: "/admin/tools",
},
{
name: (
<div className="flex">
<ClosedBookIcon
className="text-icon-settings-sidebar"
size={18}
/>
<div className="ml-1">Prompt Library</div>
</div>
),
link: "/admin/prompt-library",
},
]
: []),
...(enableEnterprise
@@ -405,7 +417,7 @@ export function ClientLayout({
/>
</div>
<div className="pb-8 relative h-full overflow-y-auto w-full">
<div className="fixed left-0 gap-x-4 px-2 top-2 h-8 px-0 mb-auto w-full items-start flex justify-end">
<div className="fixed bg-background left-0 gap-x-4 mb-8 px-4 py-2 w-full items-center flex justify-end">
<UserDropdown />
</div>
<div className="pt-20 flex overflow-y-auto overflow-x-hidden h-full px-4 md:px-12">

View File

@@ -49,7 +49,7 @@ export function AccessTypeForm({
name: "Private",
value: "private",
description:
"Only users who have explicitly been given access to this connector (through the User Groups page) can access the documents pulled in by this connector",
"Only users who have expliticly been given access to this connector (through the User Groups page) can access the documents pulled in by this connector",
},
];

Some files were not shown because too many files have changed in this diff Show More