Compare commits


1 Commit

Author      SHA1        Message  Date
Yuhong Sun  4293543a6a  k        2024-07-20 16:48:05 -07:00
25 changed files with 627 additions and 430 deletions

View File

@@ -61,7 +61,7 @@ HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
# if the title is separated out. Title is more of a "boost" than a separate field.
TITLE_CONTENT_RATIO = max(
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10))
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
)
# A list of languages passed to the LLM to rephrase the query
# For example "English,French,Spanish", be sure to use the "," separator

View File

@@ -56,16 +56,6 @@ def extract_text_from_content(content: dict) -> str:
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
@@ -127,10 +117,8 @@ def fetch_jira_issues_batch(
continue
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = (
f"{jira.fields.description}\n"
if jira.fields.description
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments]
)
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -159,18 +147,14 @@ def fetch_jira_issues_batch(
pass
metadata_dict = {}
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
metadata_dict["priority"] = priority.name
status = best_effort_get_field_from_issue(jira, "status")
if status:
metadata_dict["status"] = status.name
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
metadata_dict["resolution"] = resolution.name
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
metadata_dict["label"] = labels
if jira.fields.priority:
metadata_dict["priority"] = jira.fields.priority.name
if jira.fields.status:
metadata_dict["status"] = jira.fields.status.name
if jira.fields.resolution:
metadata_dict["resolution"] = jira.fields.resolution.name
if jira.fields.labels:
metadata_dict["label"] = jira.fields.labels
doc_batch.append(
Document(
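
One note on the `semantic_rep` change in the hunk above: the two versions differ because of Python's conditional-expression precedence, which binds the `+ "\n".join(...)` to the `else` branch only. A small standalone illustration with hypothetical values:

```python
description = None
comments = ["first", "second"]

# Ternary form: the join belongs to the else branch, so comments are dropped
# whenever a description exists, and only the comments survive when it is missing.
ternary = (
    f"{description}\n"
    if description
    else "" + "\n".join(f"Comment: {c}" for c in comments)
)

# Unconditional concatenation: description (possibly the string "None") plus all comments.
concatenated = f"{description}\n" + "\n".join(f"Comment: {c}" for c in comments)
```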

View File

@@ -20,10 +20,18 @@ schema DANSWER_CHUNK_NAME {
# `semantic_identifier` will be the channel name, but the `title` will be empty
field title type string {
indexing: summary | index | attribute
match {
gram
gram-size: 3
}
index: enable-bm25
}
field content type string {
indexing: summary | index
match {
gram
gram-size: 3
}
index: enable-bm25
}
# duplication of `content` is far from ideal, but is needed for
@@ -148,6 +156,7 @@ schema DANSWER_CHUNK_NAME {
function title_vector_score() {
expression {
# If no title, the full vector score comes from the content embedding
#query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
}
}
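
For intuition on the `match { gram gram-size: 3 }` blocks added to the `title` and `content` fields above: gram matching indexes the field as character trigrams, which helps with partial and typo-tolerant matches. A quick plain-Python illustration of how a term decomposes (this is not Vespa internals, just the n-gram idea):

```python
def char_trigrams(text: str) -> list[str]:
    # Sliding window of 3 characters, the same granularity as gram-size: 3.
    text = text.lower()
    return [text[i : i + 3] for i in range(len(text) - 2)]


char_trigrams("danswer")  # ['dan', 'ans', 'nsw', 'swe', 'wer']
```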

View File

@@ -4,6 +4,7 @@ from abc import abstractmethod
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -14,6 +15,7 @@ from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
@@ -69,6 +71,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
def embed_chunks(
self,
chunks: list[DocAwareChunk],
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[IndexChunk]:
# Cache the Title embeddings to only have to do it once
@@ -77,7 +80,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
# Create Mini Chunks for more precise matching of details
# Off by default with unedited settings
chunk_texts: list[str] = []
chunk_texts = []
chunk_mini_chunks_count = {}
for chunk_ind, chunk in enumerate(chunks):
chunk_texts.append(chunk.content)
@@ -89,9 +92,22 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
chunk_texts.extend(mini_chunk_texts)
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
embeddings = self.embedding_model.encode(
chunk_texts, text_type=EmbedTextType.PASSAGE
)
# Batching for embedding
text_batches = batch_list(chunk_texts, batch_size)
embeddings: list[list[float]] = []
len_text_batches = len(text_batches)
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
# Whether to normalize embeddings is configured only via model_configs.py; be sure to use the
# right value for the loss the model was trained with
embeddings.extend(
self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
)
# Replace the line above with the line below for easy debugging of the indexing flow,
# skipping the actual model
# embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
chunk_titles = {
chunk.source_document.get_title_for_document_index() for chunk in chunks
@@ -100,15 +116,16 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
# Drop any None or empty strings
chunk_titles_list = [title for title in chunk_titles if title]
if chunk_titles_list:
# Embed Titles in batches
title_batches = batch_list(chunk_titles_list, batch_size)
len_title_batches = len(title_batches)
for ind_batch, title_batch in enumerate(title_batches, start=1):
logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
title_embeddings = self.embedding_model.encode(
chunk_titles_list, text_type=EmbedTextType.PASSAGE
title_batch, text_type=EmbedTextType.PASSAGE
)
title_embed_dict.update(
{
title: vector
for title, vector in zip(chunk_titles_list, title_embeddings)
}
{title: vector for title, vector in zip(title_batch, title_embeddings)}
)
# Mapping embeddings to chunks
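
The batching above relies on `batch_list` from `danswer.utils.batching`, which is not shown in this diff. A minimal sketch of what such a helper presumably looks like (the actual implementation may differ):

```python
from typing import TypeVar

T = TypeVar("T")


def batch_list(lst: list[T], batch_size: int) -> list[list[T]]:
    # Split a list into consecutive batches of at most `batch_size` items,
    # preserving order so embeddings can be zipped back to their inputs.
    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
```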

View File

@@ -34,7 +34,6 @@ from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time
logger = setup_logger()
@@ -155,7 +154,6 @@ class SearchPipeline:
return cast(list[InferenceChunk], self._retrieved_chunks)
@log_function_time(print_only=True)
def _get_sections(self) -> list[InferenceSection]:
"""Returns an expanded section from each of the chunks.
If whole docs (instead of above/below context) is specified then it will give back all of the whole docs
@@ -175,11 +173,9 @@ class SearchPipeline:
expanded_inference_sections = []
# Full doc setting takes priority
if self.search_query.full_doc:
seen_document_ids = set()
unique_chunks = []
# This preserves the ordering since the chunks are retrieved in score order
for chunk in retrieved_chunks:
if chunk.document_id not in seen_document_ids:
@@ -199,6 +195,7 @@ class SearchPipeline:
),
)
)
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
@@ -243,35 +240,32 @@ class SearchPipeline:
merged_ranges = [
merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
]
flat_ranges: list[ChunkRange] = [r for ranges in merged_ranges for r in ranges]
flattened_inference_chunks: list[InferenceChunk] = []
parallel_functions_with_args = []
flat_ranges = [r for ranges in merged_ranges for r in ranges]
for chunk_range in flat_ranges:
# Don't need to fetch chunks within range for merging if chunk_above / below are 0.
if above == below == 0:
flattened_inference_chunks.extend(chunk_range.chunks)
else:
parallel_functions_with_args.append(
functions_with_args.append(
(
# If Large Chunks are introduced, additional filters need to be added here
self.document_index.id_based_retrieval,
(
self.document_index.id_based_retrieval,
(
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
IndexFilters(access_control_list=None),
),
)
# Only the document_id is needed here; any chunk in the range is fine
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
# There is no chunk-level permissioning, so this expansion around chunks
# can be assumed to be safe
IndexFilters(access_control_list=None),
),
)
if parallel_functions_with_args:
list_inference_chunks = run_functions_tuples_in_parallel(
parallel_functions_with_args, allow_failures=False
)
for inference_chunks in list_inference_chunks:
flattened_inference_chunks.extend(inference_chunks)
# list of lists of inference chunks, where each inner list needs to be combined for content
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
flattened_inference_chunks = [
chunk for sublist in list_inference_chunks for chunk in sublist
]
doc_chunk_ind_to_chunk = {
(chunk.document_id, chunk.chunk_id): chunk
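
`merge_chunk_intervals` is referenced but not shown here; a hedged sketch of the kind of interval merging it presumably performs over per-document chunk index ranges (names and shapes are assumptions):

```python
def merge_intervals(ranges: list[tuple[int, int]]) -> list[tuple[int, int]]:
    # Merge overlapping or adjacent [start, end] chunk-index ranges so each
    # document can be fetched with as few id_based_retrieval calls as possible.
    merged: list[tuple[int, int]] = []
    for start, end in sorted(ranges):
        if merged and start <= merged[-1][1] + 1:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged


merge_intervals([(0, 2), (1, 4), (7, 9)])  # -> [(0, 4), (7, 9)]
```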

View File

@@ -5,13 +5,10 @@ from typing import Optional
from typing import TYPE_CHECKING
import requests
from httpx import HTTPError
from transformers import logging as transformer_logging # type:ignore
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
@@ -106,73 +103,28 @@ class EmbeddingModel:
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def encode(
self,
texts: list[str],
text_type: EmbedTextType,
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
) -> list[list[float]]:
if not texts:
logger.warning("No texts to be embedded")
return []
def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
if text_type == EmbedTextType.QUERY and self.query_prefix:
prefixed_texts = [self.query_prefix + text for text in texts]
elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
prefixed_texts = [self.passage_prefix + text for text in texts]
else:
prefixed_texts = texts
if self.provider_type:
embed_request = EmbedRequest(
model_name=self.model_name,
texts=texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
return EmbedResponse(**response.json()).embeddings
embed_request = EmbedRequest(
model_name=self.model_name,
texts=prefixed_texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
)
# Batching for local embedding
text_batches = batch_list(texts, batch_size)
embeddings: list[list[float]] = []
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Embedding Content Texts batch {idx} of {len(text_batches)}")
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
response.raise_for_status()
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
# Whether to normalize embeddings is configured only via model_configs.py; be sure to use the
# right value for the loss the model was trained with
embeddings.extend(EmbedResponse(**response.json()).embeddings)
return embeddings
return EmbedResponse(**response.json()).embeddings
class CrossEncoderEnsembleModel:
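
For reference, the payload posted to `/encoder/bi-encoder-embed` above mirrors the `EmbedRequest` fields shown later in this diff. A hedged sketch of issuing the same request without the class (host, port, model name, and the string form of the text type are illustrative assumptions):

```python
import requests

embed_request = {
    "model_name": "intfloat/e5-base-v2",  # illustrative local model name
    "texts": ["first chunk text", "second chunk text"],
    "max_context_length": 512,
    "normalize_embeddings": True,
    "api_key": None,
    "provider_type": None,  # None means the locally hosted model is used
    "text_type": "passage",  # assumed serialized value of EmbedTextType.PASSAGE
    "manual_query_prefix": "query: ",
    "manual_passage_prefix": "passage: ",
}

response = requests.post(
    "http://localhost:9000/encoder/bi-encoder-embed", json=embed_request
)
response.raise_for_status()
embeddings = response.json()["embeddings"]
```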

View File

@@ -42,7 +42,7 @@ def test_embedding_configuration(
passage_prefix=None,
model_name=None,
)
test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)
test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
except ValueError as e:
error_msg = f"Not a valid embedding model. Exception thrown: {e}"

View File

@@ -10,7 +10,6 @@ from cohere import Client as CohereClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from vertexai.language_models import TextEmbeddingInput # type: ignore
@@ -41,131 +40,110 @@ router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
# If we are not only indexing, we don't want to retry for very long
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
def _initialize_client(
api_key: str, provider: EmbeddingProvider, model: str | None = None
) -> Any:
if provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=api_key)
elif provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=api_key)
elif provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=api_key)
elif provider == EmbeddingProvider.GOOGLE:
credentials = service_account.Credentials.from_service_account_info(
json.loads(api_key)
)
project_id = json.loads(api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
else:
raise ValueError(f"Unsupported provider: {provider}")
class CloudEmbedding:
def __init__(
self,
api_key: str,
provider: str,
def __init__(self, api_key: str, provider: str, model: str | None = None):
self.api_key = api_key
# Only needed for Google, as it is used during client setup
model: str | None = None,
) -> None:
self.model = model
try:
self.provider = EmbeddingProvider(provider.lower())
except ValueError:
raise ValueError(f"Unsupported provider: {provider}")
self.client = _initialize_client(api_key, self.provider, model)
self.client = self._initialize_client()
def _embed_openai(self, texts: list[str], model: str | None) -> list[list[float]]:
def _initialize_client(self) -> Any:
if self.provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=self.api_key)
elif self.provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=self.api_key)
elif self.provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=self.api_key)
elif self.provider == EmbeddingProvider.GOOGLE:
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.api_key)
)
project_id = json.loads(self.api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(
self.model or DEFAULT_VERTEX_MODEL
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def encode(
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
) -> list[list[float]]:
return [
self.embed(text=text, text_type=text_type, model=model_name)
for text in texts
]
def embed(
self, *, text: str, text_type: EmbedTextType, model: str | None = None
) -> list[float]:
logger.debug(f"Embedding text with provider: {self.provider}")
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(text, model)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(text, model, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(text, model, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(text, model, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def _embed_openai(self, text: str, model: str | None) -> list[float]:
if model is None:
model = DEFAULT_OPENAI_MODEL
# OpenAI does not seem to provide a truncation option; however,
# the context lengths used by Danswer are currently smaller than the max token length
# for OpenAI embeddings, so it's not a big deal
response = self.client.embeddings.create(input=texts, model=model)
return [embedding.embedding for embedding in response.data]
response = self.client.embeddings.create(input=text, model=model)
return response.data[0].embedding
def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
if model is None:
model = DEFAULT_COHERE_MODEL
# Does not use the same tokenizer as the Danswer API server, but it's approximately the same;
# empirically it's only off by a few tokens, so it's not a big deal
response = self.client.embed(
texts=texts,
texts=[text],
model=model,
input_type=embedding_type,
truncate="END",
)
return response.embeddings
return response.embeddings[0]
def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
if model is None:
model = DEFAULT_VOYAGE_MODEL
# Similar to Cohere, the API server will do approximate size chunking
# it's acceptable to miss by a few tokens
response = self.client.embed(
texts,
model=model,
input_type=embedding_type,
truncation=True, # Also this is default
)
return response.embeddings
response = self.client.embed(text, model=model, input_type=embedding_type)
return response.embeddings[0]
def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
if model is None:
model = DEFAULT_VERTEX_MODEL
embeddings = self.client.get_embeddings(
embedding = self.client.get_embeddings(
[
TextEmbeddingInput(
text,
embedding_type,
)
for text in texts
],
auto_truncate=True, # Also this is default
]
)
return [embedding.values for embedding in embeddings]
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
def embed(
self,
*,
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
) -> list[list[float]]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error embedding text with {self.provider}: {str(e)}",
)
return embedding[0].values
@staticmethod
def create(
@@ -234,52 +212,29 @@ def embed_text(
normalize_embeddings: bool,
api_key: str | None,
provider_type: str | None,
prefix: str | None,
) -> list[list[float]]:
# Third party API based embedding model
if provider_type is not None:
logger.debug(f"Embedding text with provider: {provider_type}")
if api_key is None:
raise RuntimeError("API key not provided for cloud model")
if prefix:
# This may change in the future if some providers require the user
# to manually append a prefix but this is not the case currently
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
)
cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.embed(
texts=texts,
model_name=model_name,
text_type=text_type,
)
embeddings = cloud_model.encode(texts, model_name, text_type)
# Locally running model
elif model_name is not None:
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
local_model = get_embedding_model(
hosted_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
)
else:
raise ValueError(
"Either model name or provider must be provided to run embeddings."
embeddings = hosted_model.encode(
texts, normalize_embeddings=normalize_embeddings
)
if embeddings is None:
raise RuntimeError("Failed to create Embeddings")
raise RuntimeError("Embeddings were not created")
if not isinstance(embeddings, list):
embeddings = embeddings.tolist()
return embeddings
@@ -297,17 +252,7 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
async def process_embed_request(
embed_request: EmbedRequest,
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
try:
if embed_request.text_type == EmbedTextType.QUERY:
prefix = embed_request.manual_query_prefix
elif embed_request.text_type == EmbedTextType.PASSAGE:
prefix = embed_request.manual_passage_prefix
else:
prefix = None
embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
@@ -316,13 +261,13 @@ async def process_embed_request(
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)
except Exception as e:
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
raise HTTPException(status_code=500, detail=exception_detail)
logger.exception(f"Error during embedding process:\n{str(e)}")
raise HTTPException(
status_code=500, detail="Failed to run Bi-Encoder embedding"
)
@router.post("/cross-encoder-scores")
@@ -331,11 +276,6 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
if not embed_request.documents or not embed_request.query:
raise HTTPException(
status_code=400, detail="No documents or query to be reranked"
)
try:
sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents
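
The `@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)` decorator used on the cloud `embed` call above comes from the third-party `retry` package, which also appears in the requirements diff below. A minimal, self-contained illustration of the same pattern with assumed values:

```python
from retry import retry

_RETRY_DELAY = 0.1  # assumed non-indexing values, matching the snippet above
_RETRY_TRIES = 2


@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
def call_provider() -> list[list[float]]:
    # Any exception raised in here triggers up to _RETRY_TRIES attempts,
    # with _RETRY_DELAY seconds of sleep between them; the last failure is re-raised.
    ...
```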

View File

@@ -1,7 +1,6 @@
fastapi==0.109.2
h5py==3.9.0
pydantic==1.10.13
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
tensorflow==2.15.0
@@ -10,5 +9,5 @@ transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
openai==1.14.3
cohere==5.6.1
google-cloud-aiplatform==1.58.0
cohere==5.5.8
google-cloud-aiplatform==1.58.0

View File

@@ -4,7 +4,9 @@ from shared_configs.enums import EmbedTextType
class EmbedRequest(BaseModel):
# This already includes any prefixes, the text is just passed directly to the model
texts: list[str]
# Can be None for cloud embedding model requests; error handling logic exists for other cases
model_name: str | None
max_context_length: int
@@ -12,8 +14,6 @@ class EmbedRequest(BaseModel):
api_key: str | None
provider_type: str | None
text_type: EmbedTextType
manual_query_prefix: str | None
manual_passage_prefix: str | None
class EmbedResponse(BaseModel):

View File

@@ -50,7 +50,7 @@ export default async function GalleryPage({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availableAssistants: assistants,
availablePersonas: assistants,
availableTags: tags,
llmProviders,
folders,

View File

@@ -52,7 +52,7 @@ export default async function GalleryPage({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availableAssistants: assistants,
availablePersonas: assistants,
availableTags: tags,
llmProviders,
folders,

View File

@@ -63,6 +63,7 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
import { ConfigurationModal } from "./modal/configuration/ConfigurationModal";
import { useChatContext } from "@/components/context/ChatContext";
import { v4 as uuidv4 } from "uuid";
import { orderAssistantsForUser } from "@/lib/assistants/orderAssistants";
@@ -81,29 +82,38 @@ const SYSTEM_MESSAGE_ID = -3;
export function ChatPage({
toggle,
documentSidebarInitialWidth,
defaultSelectedAssistantId,
defaultSelectedPersonaId,
toggledSidebar,
}: {
toggle: () => void;
documentSidebarInitialWidth?: number;
defaultSelectedAssistantId?: number;
defaultSelectedPersonaId?: number;
toggledSidebar: boolean;
}) {
const router = useRouter();
const searchParams = useSearchParams();
const [configModalActiveTab, setConfigModalActiveTab] = useState<
string | null
>(null);
let {
user,
chatSessions,
availableSources,
availableDocumentSets,
availableAssistants,
availablePersonas,
llmProviders,
folders,
openedFolders,
} = useChatContext();
// chat session
const filteredAssistants = orderAssistantsForUser(availablePersonas, user);
const [selectedAssistant, setSelectedAssistant] = useState<Persona | null>(
null
);
const [alternativeGeneratingAssistant, setAlternativeGeneratingAssistant] =
useState<Persona | null>(null);
const router = useRouter();
const searchParams = useSearchParams();
const existingChatIdRaw = searchParams.get("chatId");
const existingChatSessionId = existingChatIdRaw
? parseInt(existingChatIdRaw)
@@ -113,51 +123,9 @@ export function ChatPage({
);
const chatSessionIdRef = useRef<number | null>(existingChatSessionId);
// LLM
const llmOverrideManager = useLlmOverride(selectedChatSession);
// Assistants
const filteredAssistants = orderAssistantsForUser(availableAssistants, user);
const existingChatSessionAssistantId = selectedChatSession?.persona_id;
const [selectedAssistant, setSelectedAssistant] = useState<
Persona | undefined
>(
// NOTE: look through available assistants here, so that even if the user
// has hidden this assistant it still shows the correct assistant when
// going back to an old chat session
existingChatSessionAssistantId !== undefined
? availableAssistants.find(
(assistant) => assistant.id === existingChatSessionAssistantId
)
: defaultSelectedAssistantId !== undefined
? availableAssistants.find(
(assistant) => assistant.id === defaultSelectedAssistantId
)
: undefined
);
const setSelectedAssistantFromId = (assistantId: number) => {
// NOTE: also intentionally look through available assistants here, so that
// even if the user has hidden an assistant they can still go back to it
// for old chats
setSelectedAssistant(
availableAssistants.find((assistant) => assistant.id === assistantId)
);
};
const liveAssistant =
selectedAssistant || filteredAssistants[0] || availableAssistants[0];
// this is for "@"ing assistants
const [alternativeAssistant, setAlternativeAssistant] =
useState<Persona | null>(null);
// this is used to track which assistant is being used to generate the current message
// for example, this would come into play when:
// 1. default assistant is `Danswer`
// 2. we "@"ed the `GPT` assistant and sent a message
// 3. while the `GPT` assistant message is generating, we "@" the `Paraphrase` assistant
const [alternativeGeneratingAssistant, setAlternativeGeneratingAssistant] =
useState<Persona | null>(null);
const existingChatSessionPersonaId = selectedChatSession?.persona_id;
// used to track whether or not the initial "submit on load" has been performed
// this only applies if `?submit-on-load=true` or `?submit-on-load=1` is in the URL
@@ -214,10 +182,14 @@ export function ChatPage({
async function initialSessionFetch() {
if (existingChatSessionId === null) {
setIsFetchingChatMessages(false);
if (defaultSelectedAssistantId !== undefined) {
setSelectedAssistantFromId(defaultSelectedAssistantId);
if (defaultSelectedPersonaId !== undefined) {
setSelectedPersona(
filteredAssistants.find(
(persona) => persona.id === defaultSelectedPersonaId
)
);
} else {
setSelectedAssistant(undefined);
setSelectedPersona(undefined);
}
setCompleteMessageDetail({
sessionId: null,
@@ -242,7 +214,12 @@ export function ChatPage({
);
const chatSession = (await response.json()) as BackendChatSession;
setSelectedAssistantFromId(chatSession.persona_id);
setSelectedPersona(
filteredAssistants.find(
(persona) => persona.id === chatSession.persona_id
)
);
const newMessageMap = processRawChatHistory(chatSession.messages);
const newMessageHistory = buildLatestMessageChain(newMessageMap);
@@ -396,18 +373,32 @@ export function ChatPage({
)
: { aiMessage: null };
const [selectedPersona, setSelectedPersona] = useState<Persona | undefined>(
existingChatSessionPersonaId !== undefined
? filteredAssistants.find(
(persona) => persona.id === existingChatSessionPersonaId
)
: defaultSelectedPersonaId !== undefined
? filteredAssistants.find(
(persona) => persona.id === defaultSelectedPersonaId
)
: undefined
);
const livePersona =
selectedPersona || filteredAssistants[0] || availablePersonas[0];
const [chatSessionSharedStatus, setChatSessionSharedStatus] =
useState<ChatSessionSharedStatus>(ChatSessionSharedStatus.Private);
useEffect(() => {
if (messageHistory.length === 0 && chatSessionIdRef.current === null) {
setSelectedAssistant(
setSelectedPersona(
filteredAssistants.find(
(persona) => persona.id === defaultSelectedAssistantId
(persona) => persona.id === defaultSelectedPersonaId
)
);
}
}, [defaultSelectedAssistantId]);
}, [defaultSelectedPersonaId]);
const [
selectedDocuments,
@@ -423,7 +414,7 @@ export function ChatPage({
useEffect(() => {
async function fetchMaxTokens() {
const response = await fetch(
`/api/chat/max-selected-document-tokens?persona_id=${liveAssistant.id}`
`/api/chat/max-selected-document-tokens?persona_id=${livePersona.id}`
);
if (response.ok) {
const maxTokens = (await response.json()).max_tokens as number;
@@ -432,12 +423,12 @@ export function ChatPage({
}
fetchMaxTokens();
}, [liveAssistant]);
}, [livePersona]);
const filterManager = useFilters();
const [finalAvailableSources, finalAvailableDocumentSets] =
computeAvailableFilters({
selectedPersona: selectedAssistant,
selectedPersona,
availableSources,
availableDocumentSets,
});
@@ -633,16 +624,16 @@ export function ChatPage({
queryOverride,
forceSearch,
isSeededChat,
alternativeAssistantOverride = null,
alternativeAssistant = null,
}: {
messageIdToResend?: number;
messageOverride?: string;
queryOverride?: string;
forceSearch?: boolean;
isSeededChat?: boolean;
alternativeAssistantOverride?: Persona | null;
alternativeAssistant?: Persona | null;
} = {}) => {
setAlternativeGeneratingAssistant(alternativeAssistantOverride);
setAlternativeGeneratingAssistant(alternativeAssistant);
clientScrollToBottom();
let currChatSessionId: number;
@@ -652,7 +643,7 @@ export function ChatPage({
if (isNewSession) {
currChatSessionId = await createChatSession(
liveAssistant?.id || 0,
livePersona?.id || 0,
searchParamBasedChatSessionName
);
} else {
@@ -730,9 +721,9 @@ export function ChatPage({
parentMessage = frozenMessageMap.get(SYSTEM_MESSAGE_ID) || null;
}
const currentAssistantId = alternativeAssistantOverride
? alternativeAssistantOverride.id
: alternativeAssistant?.id || liveAssistant.id;
const currentAssistantId = alternativeAssistant
? alternativeAssistant.id
: selectedAssistant?.id;
resetInputBar();
@@ -760,7 +751,7 @@ export function ChatPage({
fileDescriptors: currentMessageFiles,
parentMessageId: lastSuccessfulMessageId,
chatSessionId: currChatSessionId,
promptId: liveAssistant?.prompts[0]?.id || 0,
promptId: livePersona?.prompts[0]?.id || 0,
filters: buildFilters(
filterManager.selectedSources,
filterManager.selectedDocumentSets,
@@ -877,7 +868,7 @@ export function ChatPage({
files: finalMessage?.files || aiMessageImages || [],
toolCalls: finalMessage?.tool_calls || toolCalls,
parentMessageId: newUserMessageId,
alternateAssistantID: alternativeAssistant?.id,
alternateAssistantID: selectedAssistant?.id,
},
]);
}
@@ -973,23 +964,19 @@ export function ChatPage({
}
};
const onAssistantChange = (assistant: Persona | null) => {
if (assistant && assistant.id !== liveAssistant.id) {
const onPersonaChange = (persona: Persona | null) => {
if (persona && persona.id !== livePersona.id) {
// remove uploaded files
setCurrentMessageFiles([]);
setSelectedAssistant(assistant);
setSelectedPersona(persona);
textAreaRef.current?.focus();
router.push(buildChatUrl(searchParams, null, assistant.id));
router.push(buildChatUrl(searchParams, null, persona.id));
}
};
const handleImageUpload = (acceptedFiles: File[]) => {
const llmAcceptsImages = checkLLMSupportsImageInput(
...getFinalLLM(
llmProviders,
liveAssistant,
llmOverrideManager.llmOverride
)
...getFinalLLM(llmProviders, livePersona, llmOverrideManager.llmOverride)
);
const imageFiles = acceptedFiles.filter((file) =>
file.type.startsWith("image/")
@@ -1071,23 +1058,23 @@ export function ChatPage({
useEffect(() => {
const includes = checkAnyAssistantHasSearch(
messageHistory,
availableAssistants,
liveAssistant
availablePersonas,
livePersona
);
setRetrievalEnabled(includes);
}, [messageHistory, availableAssistants, liveAssistant]);
}, [messageHistory, availablePersonas, livePersona]);
const [retrievalEnabled, setRetrievalEnabled] = useState(() => {
return checkAnyAssistantHasSearch(
messageHistory,
availableAssistants,
liveAssistant
availablePersonas,
livePersona
);
});
const innerSidebarElementRef = useRef<HTMLDivElement>(null);
const currentPersona = alternativeAssistant || liveAssistant;
const currentPersona = selectedAssistant || livePersona;
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
@@ -1189,8 +1176,21 @@ export function ChatPage({
/>
)}
<ConfigurationModal
chatSessionId={chatSessionIdRef.current!}
activeTab={configModalActiveTab}
setActiveTab={setConfigModalActiveTab}
onClose={() => setConfigModalActiveTab(null)}
filterManager={filterManager}
availableAssistants={filteredAssistants}
selectedAssistant={livePersona}
setSelectedAssistant={onPersonaChange}
llmProviders={llmProviders}
llmOverrideManager={llmOverrideManager}
/>
<div className="flex h-[calc(100dvh)] flex-col w-full">
{liveAssistant && (
{livePersona && (
<FunctionalHeader
page="chat"
setSharingModalVisible={
@@ -1239,7 +1239,7 @@ export function ChatPage({
!isStreaming && (
<ChatIntro
availableSources={finalAvailableSources}
selectedPersona={liveAssistant}
selectedPersona={livePersona}
/>
)}
<div
@@ -1319,7 +1319,7 @@ export function ChatPage({
const currentAlternativeAssistant =
message.alternateAssistantID != null
? availableAssistants.find(
? availablePersonas.find(
(persona) =>
persona.id ==
message.alternateAssistantID
@@ -1342,7 +1342,7 @@ export function ChatPage({
toggleDocumentSelectionAspects
}
docs={message.documents}
currentPersona={liveAssistant}
currentPersona={livePersona}
alternativeAssistant={
currentAlternativeAssistant
}
@@ -1352,7 +1352,7 @@ export function ChatPage({
query={
messageHistory[i]?.query || undefined
}
personaName={liveAssistant.name}
personaName={livePersona.name}
citedDocuments={getCitedDocumentsFromMessage(
message
)}
@@ -1404,7 +1404,7 @@ export function ChatPage({
messageIdToResend:
previousMessage.messageId,
queryOverride: newQuery,
alternativeAssistantOverride:
alternativeAssistant:
currentAlternativeAssistant,
});
}
@@ -1435,7 +1435,7 @@ export function ChatPage({
messageIdToResend:
previousMessage.messageId,
forceSearch: true,
alternativeAssistantOverride:
alternativeAssistant:
currentAlternativeAssistant,
});
} else {
@@ -1460,9 +1460,9 @@ export function ChatPage({
return (
<div key={messageReactComponentKey}>
<AIMessage
currentPersona={liveAssistant}
currentPersona={livePersona}
messageId={message.messageId}
personaName={liveAssistant.name}
personaName={livePersona.name}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
@@ -1481,13 +1481,13 @@ export function ChatPage({
key={`${messageHistory.length}-${chatSessionIdRef.current}`}
>
<AIMessage
currentPersona={liveAssistant}
currentPersona={livePersona}
alternativeAssistant={
alternativeGeneratingAssistant ??
alternativeAssistant
selectedAssistant
}
messageId={null}
personaName={liveAssistant.name}
personaName={livePersona.name}
content={
<div className="text-sm my-auto">
<ThreeDots
@@ -1513,7 +1513,7 @@ export function ChatPage({
{currentPersona &&
currentPersona.starter_messages &&
currentPersona.starter_messages.length > 0 &&
selectedAssistant &&
selectedPersona &&
messageHistory.length === 0 &&
!isFetchingChatMessages && (
<div
@@ -1570,25 +1570,32 @@ export function ChatPage({
<ChatInputBar
showDocs={() => setDocumentSelection(true)}
selectedDocuments={selectedDocuments}
// assistant stuff
assistantOptions={filteredAssistants}
selectedAssistant={liveAssistant}
setSelectedAssistant={onAssistantChange}
setAlternativeAssistant={setAlternativeAssistant}
alternativeAssistant={alternativeAssistant}
// end assistant stuff
setSelectedAssistant={onPersonaChange}
onSetSelectedAssistant={(
alternativeAssistant: Persona | null
) => {
setSelectedAssistant(alternativeAssistant);
}}
alternativeAssistant={selectedAssistant}
personas={filteredAssistants}
message={message}
setMessage={setMessage}
onSubmit={onSubmit}
isStreaming={isStreaming}
setIsCancelled={setIsCancelled}
retrievalDisabled={
!personaIncludesRetrieval(currentPersona)
}
filterManager={filterManager}
llmOverrideManager={llmOverrideManager}
selectedAssistant={livePersona}
files={currentMessageFiles}
setFiles={setCurrentMessageFiles}
handleFileUpload={handleImageUpload}
setConfigModalActiveTab={setConfigModalActiveTab}
textAreaRef={textAreaRef}
chatSessionId={chatSessionIdRef.current!}
availableAssistants={availablePersonas}
/>
</div>
</div>

View File

@@ -5,10 +5,10 @@ import { ChatPage } from "./ChatPage";
import FunctionalWrapper from "./shared_chat_search/FunctionalWrapper";
export default function WrappedChat({
defaultAssistantId,
defaultPersonaId,
initiallyToggled,
}: {
defaultAssistantId?: number;
defaultPersonaId?: number;
initiallyToggled: boolean;
}) {
return (
@@ -17,7 +17,7 @@ export default function WrappedChat({
content={(toggledSidebar, toggle) => (
<ChatPage
toggle={toggle}
defaultSelectedAssistantId={defaultAssistantId}
defaultSelectedPersonaId={defaultPersonaId}
toggledSidebar={toggledSidebar}
/>
)}

View File

@@ -21,6 +21,7 @@ import { IconType } from "react-icons";
import Popup from "../../../components/popup/Popup";
import { LlmTab } from "../modal/configuration/LlmTab";
import { AssistantsTab } from "../modal/configuration/AssistantsTab";
import ChatInputAssistant from "./ChatInputAssistant";
import { DanswerDocument } from "@/lib/search/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Tooltip } from "@/components/tooltip/Tooltip";
@@ -28,6 +29,7 @@ import { Hoverable } from "@/components/Hoverable";
const MAX_INPUT_HEIGHT = 200;
export function ChatInputBar({
personas,
showDocs,
selectedDocuments,
message,
@@ -35,32 +37,34 @@ export function ChatInputBar({
onSubmit,
isStreaming,
setIsCancelled,
retrievalDisabled,
filterManager,
llmOverrideManager,
// assistants
onSetSelectedAssistant,
selectedAssistant,
assistantOptions,
setSelectedAssistant,
setAlternativeAssistant,
files,
setSelectedAssistant,
setFiles,
handleFileUpload,
setConfigModalActiveTab,
textAreaRef,
alternativeAssistant,
chatSessionId,
availableAssistants,
}: {
showDocs: () => void;
selectedDocuments: DanswerDocument[];
assistantOptions: Persona[];
setAlternativeAssistant: (alternativeAssistant: Persona | null) => void;
availableAssistants: Persona[];
onSetSelectedAssistant: (alternativeAssistant: Persona | null) => void;
setSelectedAssistant: (assistant: Persona) => void;
personas: Persona[];
message: string;
setMessage: (message: string) => void;
onSubmit: () => void;
isStreaming: boolean;
setIsCancelled: (value: boolean) => void;
retrievalDisabled: boolean;
filterManager: FilterManager;
llmOverrideManager: LlmOverrideManager;
selectedAssistant: Persona;
@@ -68,6 +72,7 @@ export function ChatInputBar({
files: FileDescriptor[];
setFiles: (files: FileDescriptor[]) => void;
handleFileUpload: (files: File[]) => void;
setConfigModalActiveTab: (tab: string) => void;
textAreaRef: React.RefObject<HTMLTextAreaElement>;
chatSessionId?: number;
}) {
@@ -131,10 +136,8 @@ export function ChatInputBar({
};
// Update selected persona
const updatedTaggedAssistant = (assistant: Persona) => {
setAlternativeAssistant(
assistant.id == selectedAssistant.id ? null : assistant
);
const updateCurrentPersona = (persona: Persona) => {
onSetSelectedAssistant(persona.id == selectedAssistant.id ? null : persona);
hideSuggestions();
setMessage("");
};
@@ -157,8 +160,8 @@ export function ChatInputBar({
}
};
const assistantTagOptions = assistantOptions.filter((assistant) =>
assistant.name.toLowerCase().startsWith(
const filteredPersonas = personas.filter((persona) =>
persona.name.toLowerCase().startsWith(
message
.slice(message.lastIndexOf("@") + 1)
.split(/\s/)[0]
@@ -171,18 +174,18 @@ export function ChatInputBar({
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (
showSuggestions &&
assistantTagOptions.length > 0 &&
filteredPersonas.length > 0 &&
(e.key === "Tab" || e.key == "Enter")
) {
e.preventDefault();
if (assistantIconIndex == assistantTagOptions.length) {
if (assistantIconIndex == filteredPersonas.length) {
window.open("/assistants/new", "_blank");
hideSuggestions();
setMessage("");
} else {
const option =
assistantTagOptions[assistantIconIndex >= 0 ? assistantIconIndex : 0];
updatedTaggedAssistant(option);
filteredPersonas[assistantIconIndex >= 0 ? assistantIconIndex : 0];
updateCurrentPersona(option);
}
}
if (!showSuggestions) {
@@ -192,7 +195,7 @@ export function ChatInputBar({
if (e.key === "ArrowDown") {
e.preventDefault();
setAssistantIconIndex((assistantIconIndex) =>
Math.min(assistantIconIndex + 1, assistantTagOptions.length)
Math.min(assistantIconIndex + 1, filteredPersonas.length)
);
} else if (e.key === "ArrowUp") {
e.preventDefault();
@@ -216,36 +219,35 @@ export function ChatInputBar({
mx-auto
"
>
{showSuggestions && assistantTagOptions.length > 0 && (
{showSuggestions && filteredPersonas.length > 0 && (
<div
ref={suggestionsRef}
className="text-sm absolute inset-x-0 top-0 w-full transform -translate-y-full"
>
<div className="rounded-lg py-1.5 bg-background border border-border-medium shadow-lg mx-2 px-1.5 mt-2 rounded z-10">
{assistantTagOptions.map((currentAssistant, index) => (
{filteredPersonas.map((currentPersona, index) => (
<button
key={index}
className={`px-2 ${
assistantIconIndex == index && "bg-hover-lightish"
} rounded rounded-lg content-start flex gap-x-1 py-2 w-full hover:bg-hover-lightish cursor-pointer`}
onClick={() => {
updatedTaggedAssistant(currentAssistant);
updateCurrentPersona(currentPersona);
}}
>
<p className="font-bold">{currentAssistant.name}</p>
<p className="font-bold">{currentPersona.name}</p>
<p className="line-clamp-1">
{currentAssistant.id == selectedAssistant.id &&
{currentPersona.id == selectedAssistant.id &&
"(default) "}
{currentAssistant.description}
{currentPersona.description}
</p>
</button>
))}
<a
key={assistantTagOptions.length}
key={filteredPersonas.length}
target="_blank"
className={`${
assistantIconIndex == assistantTagOptions.length &&
"bg-hover"
assistantIconIndex == filteredPersonas.length && "bg-hover"
} rounded rounded-lg px-3 flex gap-x-1 py-2 w-full items-center hover:bg-hover-lightish cursor-pointer"`}
href="/assistants/new"
>
@@ -299,7 +301,7 @@ export function ChatInputBar({
<Hoverable
icon={FiX}
onClick={() => setAlternativeAssistant(null)}
onClick={() => onSetSelectedAssistant(null)}
/>
</div>
</div>
@@ -407,7 +409,7 @@ export function ChatInputBar({
removePadding
content={(close) => (
<AssistantsTab
availableAssistants={assistantOptions}
availableAssistants={availableAssistants}
llmProviders={llmProviders}
selectedAssistant={selectedAssistant}
onSelect={(assistant) => {

View File

@@ -505,7 +505,7 @@ export function removeMessage(
export function checkAnyAssistantHasSearch(
messageHistory: Message[],
availableAssistants: Persona[],
availablePersonas: Persona[],
livePersona: Persona
): boolean {
const response =
@@ -516,8 +516,8 @@ export function checkAnyAssistantHasSearch(
) {
return false;
}
const alternateAssistant = availableAssistants.find(
(assistant) => assistant.id === message.alternateAssistantID
const alternateAssistant = availablePersonas.find(
(persona) => persona.id === message.alternateAssistantID
);
return alternateAssistant
? personaIncludesRetrieval(alternateAssistant)

View File

@@ -798,7 +798,7 @@ export const HumanMessage = ({
!isEditing &&
(!files || files.length === 0)
) && "ml-auto"
} relative max-w-[70%] mb-auto whitespace-break-spaces rounded-3xl bg-user px-5 py-2.5`}
} relative max-w-[70%] mb-auto rounded-3xl bg-user px-5 py-2.5`}
>
{content}
</div>

View File

@@ -62,10 +62,11 @@ export function AssistantsTab({
toolName = "Image Generation";
toolIcon = <FiImage className="mr-1 my-auto" />;
}
return (
<Bubble key={tool.id} isSelected={false}>
<div className="flex line-wrap break-all flex-row gap-1">
<div className="flex-none my-auto">{toolIcon}</div>
<div className="flex flex-row gap-1">
{toolIcon}
{toolName}
</div>
</Bubble>

View File

@@ -0,0 +1,180 @@
"use client";
import React, { useEffect } from "react";
import { Modal } from "../../../../components/Modal";
import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
import { FiltersTab } from "./FiltersTab";
import { FiCpu, FiFilter, FiX } from "react-icons/fi";
import { IconType } from "react-icons";
import { FaBrain } from "react-icons/fa";
import { AssistantsTab } from "./AssistantsTab";
import { Persona } from "@/app/admin/assistants/interfaces";
import { LlmTab } from "./LlmTab";
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
import { AssistantsIcon, IconProps } from "@/components/icons/icons";
const TabButton = ({
label,
icon: Icon,
isActive,
onClick,
}: {
label: string;
icon: IconType;
isActive: boolean;
onClick: () => void;
}) => (
<button
onClick={onClick}
className={`
pb-4
pt-6
px-2
text-emphasis
font-bold
${isActive ? "border-b-2 border-accent" : ""}
hover:bg-hover-light
hover:text-strong
transition
duration-200
ease-in-out
flex
`}
>
<Icon className="inline-block mr-2 my-auto" size="16" />
<p className="my-auto">{label}</p>
</button>
);
export function ConfigurationModal({
activeTab,
setActiveTab,
onClose,
availableAssistants,
selectedAssistant,
setSelectedAssistant,
filterManager,
llmProviders,
llmOverrideManager,
chatSessionId,
}: {
activeTab: string | null;
setActiveTab: (tab: string | null) => void;
onClose: () => void;
availableAssistants: Persona[];
selectedAssistant: Persona;
setSelectedAssistant: (assistant: Persona) => void;
filterManager: FilterManager;
llmProviders: LLMProviderDescriptor[];
llmOverrideManager: LlmOverrideManager;
chatSessionId?: number;
}) {
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === "Escape") {
onClose();
}
};
document.addEventListener("keydown", handleKeyDown);
return () => {
document.removeEventListener("keydown", handleKeyDown);
};
}, [onClose]);
if (!activeTab) return null;
return (
<Modal
onOutsideClick={onClose}
noPadding
className="
w-4/6
h-4/6
flex
flex-col
"
>
<div className="rounded flex flex-col overflow-hidden">
<div className="mb-4">
<div className="flex border-b border-border bg-background-emphasis">
<div className="flex px-6 gap-x-2">
<TabButton
label="Assistants"
icon={FaBrain}
isActive={activeTab === "assistants"}
onClick={() => setActiveTab("assistants")}
/>
<TabButton
label="Models"
icon={FiCpu}
isActive={activeTab === "llms"}
onClick={() => setActiveTab("llms")}
/>
<TabButton
label="Filters"
icon={FiFilter}
isActive={activeTab === "filters"}
onClick={() => setActiveTab("filters")}
/>
</div>
<button
className="
ml-auto
px-1
py-1
text-xs
font-medium
rounded
hover:bg-hover
focus:outline-none
focus:ring-2
focus:ring-offset-2
focus:ring-subtle
flex
items-center
h-fit
my-auto
mr-5
"
onClick={onClose}
>
<FiX size={24} />
</button>
</div>
</div>
<div className="flex flex-col overflow-y-auto">
<div className="px-8 pt-4">
{activeTab === "filters" && (
<FiltersTab filterManager={filterManager} />
)}
{activeTab === "llms" && (
<LlmTab
chatSessionId={chatSessionId}
llmOverrideManager={llmOverrideManager}
currentAssistant={selectedAssistant}
/>
)}
{activeTab === "assistants" && (
<div>
<AssistantsTab
availableAssistants={availableAssistants}
llmProviders={llmProviders}
selectedAssistant={selectedAssistant}
onSelect={(assistant) => {
setSelectedAssistant(assistant);
onClose();
}}
/>
</div>
)}
</div>
</div>
</div>
</Modal>
);
}

View File

@@ -35,7 +35,7 @@ export default async function Page({
folders,
toggleSidebar,
openedFolders,
defaultAssistantId,
defaultPersonaId,
finalDocumentSidebarInitialWidth,
shouldShowWelcomeModal,
shouldDisplaySourcesIncompleteModal,
@@ -58,7 +58,7 @@ export default async function Page({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availableAssistants: assistants,
availablePersonas: assistants,
availableTags: tags,
llmProviders,
folders,
@@ -66,7 +66,7 @@ export default async function Page({
}}
>
<WrappedChat
defaultAssistantId={defaultAssistantId}
defaultPersonaId={defaultPersonaId}
initiallyToggled={toggleSidebar}
/>
</ChatProvider>

View File

@@ -0,0 +1,111 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import { BasicSelectable } from "@/components/BasicClickable";
import { AssistantsIcon } from "@/components/icons/icons";
import { User } from "@/lib/types";
import { Text } from "@tremor/react";
import Link from "next/link";
import { FaRobot } from "react-icons/fa";
import { FiEdit2 } from "react-icons/fi";
function AssistantDisplay({
persona,
onSelect,
user,
}: {
persona: Persona;
onSelect: (persona: Persona) => void;
user: User | null;
}) {
const isEditable =
(!user || user.id === persona.owner?.id) &&
!persona.default_persona &&
(!persona.is_public || !user || user.role === "admin");
return (
<div className="flex">
<div className="w-full" onClick={() => onSelect(persona)}>
<BasicSelectable selected={false} fullWidth>
<div className="flex">
<div className="truncate w-48 3xl:w-56 flex">
<AssistantsIcon className="mr-2 my-auto" size={16} />{" "}
{persona.name}
</div>
</div>
</BasicSelectable>
</div>
{isEditable && (
<div className="pl-2 my-auto">
<Link href={`/assistants/edit/${persona.id}`}>
<FiEdit2
className="my-auto ml-auto hover:bg-hover p-0.5"
size={20}
/>
</Link>
</div>
)}
</div>
);
}
export function AssistantsTab({
personas,
onPersonaChange,
user,
}: {
personas: Persona[];
onPersonaChange: (persona: Persona | null) => void;
user: User | null;
}) {
const globalAssistants = personas.filter((persona) => persona.is_public);
const personalAssistants = personas.filter(
(persona) =>
(!user || persona.users.some((u) => u.id === user.id)) &&
!persona.is_public
);
return (
<div className="mt-4 pb-1 overflow-y-auto h-full flex flex-col gap-y-1">
<Text className="mx-3 text-xs mb-4">
Select an Assistant below to begin a new chat with them!
</Text>
<div className="mx-3">
{globalAssistants.length > 0 && (
<>
<div className="text-xs text-subtle flex pb-0.5 ml-1 mb-1.5 font-bold">
Global
</div>
{globalAssistants.map((persona) => {
return (
<AssistantDisplay
key={persona.id}
persona={persona}
onSelect={onPersonaChange}
user={user}
/>
);
})}
</>
)}
{personalAssistants.length > 0 && (
<>
<div className="text-xs text-subtle flex pb-0.5 ml-1 mb-1.5 mt-5 font-bold">
Personal
</div>
{personalAssistants.map((persona) => {
return (
<AssistantDisplay
key={persona.id}
persona={persona}
onSelect={onPersonaChange}
user={user}
/>
);
})}
</>
)}
</div>
</div>
);
}

View File

@@ -11,7 +11,7 @@ export default function FixedLogo() {
const enterpriseSettings = combinedSettings?.enterpriseSettings;
return (
<div className="absolute flex z-40 left-4 top-2">
<div className="fixed flex z-40 left-4 top-2">
{" "}
<a href="/chat" className="ml-7 text-text-700 text-xl">
<div>

View File

@@ -18,7 +18,7 @@ interface ChatContextProps {
chatSessions: ChatSession[];
availableSources: ValidSources[];
availableDocumentSets: DocumentSet[];
availableAssistants: Persona[];
availablePersonas: Persona[];
availableTags: Tag[];
llmProviders: LLMProviderDescriptor[];
folders: Folder[];

View File

@@ -19,7 +19,7 @@ export function Citation({
return (
<CustomTooltip
citation
content={<div className="inline-block p-0 m-0 truncate">{link}</div>}
content={<p className="inline-block p-0 m-0 truncate">{link}</p>}
>
<a
onClick={() => (link ? window.open(link, "_blank") : undefined)}

View File

@@ -30,6 +30,7 @@ import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS";
interface FetchChatDataResult {
user: User | null;
chatSessions: ChatSession[];
ccPairs: CCPairBasicInfo[];
availableSources: ValidSources[];
documentSets: DocumentSet[];
@@ -38,7 +39,7 @@ interface FetchChatDataResult {
llmProviders: LLMProviderDescriptor[];
folders: Folder[];
openedFolders: Record<string, boolean>;
defaultAssistantId?: number;
defaultPersonaId?: number;
toggleSidebar: boolean;
finalDocumentSidebarInitialWidth?: number;
shouldShowWelcomeModal: boolean;
@@ -149,9 +150,9 @@ export async function fetchChatData(searchParams: {
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
}
const defaultAssistantIdRaw = searchParams["assistantId"];
const defaultAssistantId = defaultAssistantIdRaw
? parseInt(defaultAssistantIdRaw)
const defaultPersonaIdRaw = searchParams["assistantId"];
const defaultPersonaId = defaultPersonaIdRaw
? parseInt(defaultPersonaIdRaw)
: undefined;
const documentSidebarCookieInitialWidth = cookies().get(
@@ -208,7 +209,7 @@ export async function fetchChatData(searchParams: {
llmProviders,
folders,
openedFolders,
defaultAssistantId,
defaultPersonaId,
finalDocumentSidebarInitialWidth,
toggleSidebar,
shouldShowWelcomeModal,