Mirror of https://github.com/onyx-dot-app/onyx.git
Synced 2026-02-19 16:55:46 +00:00

Compare commits: eval/combi ... eval/split (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 4293543a6a |  |
@@ -61,7 +61,7 @@ HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
# if the title is separated out. Title is most of a "boost" than a separate field.
TITLE_CONTENT_RATIO = max(
    0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10))
    0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
)
# A list of languages passed to the LLM to rephase the query
# For example "English,French,Spanish", be sure to use the "," separator
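For context (not part of this commit), a minimal sketch of how the clamped environment override above behaves, assuming standard `os.environ` semantics:

```python
import os

# Same clamp pattern as the hunk above: fall back to the default when the variable
# is unset or empty, then clamp the parsed float into the [0, 1] range.
TITLE_CONTENT_RATIO = max(
    0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
)

# An out-of-range value such as TITLE_CONTENT_RATIO=1.7 still yields 1.0 after
# clamping, and an unset or empty variable falls back to the 0.20 default.
print(TITLE_CONTENT_RATIO)
```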
@@ -56,16 +56,6 @@ def extract_text_from_content(content: dict) -> str:
    return " ".join(texts)


def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
    if hasattr(jira_issue.fields, field):
        return getattr(jira_issue.fields, field)

    try:
        return jira_issue.raw["fields"][field]
    except Exception:
        return None


def _get_comment_strs(
    jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
@@ -127,10 +117,8 @@ def fetch_jira_issues_batch(
            continue

        comments = _get_comment_strs(jira, comment_email_blacklist)
        semantic_rep = (
            f"{jira.fields.description}\n"
            if jira.fields.description
            else "" + "\n".join([f"Comment: {comment}" for comment in comments])
        semantic_rep = f"{jira.fields.description}\n" + "\n".join(
            [f"Comment: {comment}" for comment in comments]
        )

        page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -159,18 +147,14 @@ def fetch_jira_issues_batch(
                pass

        metadata_dict = {}
        priority = best_effort_get_field_from_issue(jira, "priority")
        if priority:
            metadata_dict["priority"] = priority.name
        status = best_effort_get_field_from_issue(jira, "status")
        if status:
            metadata_dict["status"] = status.name
        resolution = best_effort_get_field_from_issue(jira, "resolution")
        if resolution:
            metadata_dict["resolution"] = resolution.name
        labels = best_effort_get_field_from_issue(jira, "labels")
        if labels:
            metadata_dict["label"] = labels
        if jira.fields.priority:
            metadata_dict["priority"] = jira.fields.priority.name
        if jira.fields.status:
            metadata_dict["status"] = jira.fields.status.name
        if jira.fields.resolution:
            metadata_dict["resolution"] = jira.fields.resolution.name
        if jira.fields.labels:
            metadata_dict["label"] = jira.fields.labels

        doc_batch.append(
            Document(
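Illustrative usage sketch (not part of the diff) for the `best_effort_get_field_from_issue` helper above: it prefers the typed `fields` attribute and falls back to the raw API payload, returning `None` when the field is missing. The issue object below is a hypothetical stand-in for a `jira.Issue`.

```python
from typing import Any


class _FakeFields:
    def __init__(self) -> None:
        self.priority = type("Priority", (), {"name": "High"})()


class _FakeIssue:
    """Hypothetical stand-in for a jira.Issue, for illustration only."""

    def __init__(self) -> None:
        self.fields = _FakeFields()
        self.raw = {"fields": {"customfield_123": "escalated"}}


def best_effort_get_field_from_issue(jira_issue: Any, field: str) -> Any:
    # Same logic as the helper in the hunk above.
    if hasattr(jira_issue.fields, field):
        return getattr(jira_issue.fields, field)

    try:
        return jira_issue.raw["fields"][field]
    except Exception:
        return None


issue = _FakeIssue()
print(best_effort_get_field_from_issue(issue, "priority").name)    # High
print(best_effort_get_field_from_issue(issue, "customfield_123"))  # escalated
print(best_effort_get_field_from_issue(issue, "missing"))          # None
```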
@@ -20,10 +20,18 @@ schema DANSWER_CHUNK_NAME {
    # `semantic_identifier` will be the channel name, but the `title` will be empty
    field title type string {
        indexing: summary | index | attribute
        match {
            gram
            gram-size: 3
        }
        index: enable-bm25
    }
    field content type string {
        indexing: summary | index
        match {
            gram
            gram-size: 3
        }
        index: enable-bm25
    }
    # duplication of `content` is far from ideal, but is needed for
@@ -148,6 +156,7 @@ schema DANSWER_CHUNK_NAME {
    function title_vector_score() {
        expression {
            # If no title, the full vector score comes from the content embedding
            #query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
            if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
        }
    }
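For context (not part of the schema diff): `TITLE_CONTENT_RATIO` is intended to weight a title-similarity score against a content-similarity score. A minimal sketch of such a blend, assuming plain cosine-style similarity inputs; the exact Vespa ranking expression is only partially visible in the hunk above, so treat the formula as illustrative.

```python
def blended_vector_score(
    title_sim: float,
    content_sim: float,
    title_content_ratio: float = 0.20,
    has_title: bool = True,
) -> float:
    # If there is no title, the full vector score comes from the content
    # similarity, mirroring the skip_title branch in title_vector_score() above.
    if not has_title:
        return content_sim
    return title_content_ratio * title_sim + (1 - title_content_ratio) * content_sim


print(blended_vector_score(title_sim=0.9, content_sim=0.6))                   # ~0.66
print(blended_vector_score(title_sim=0.9, content_sim=0.6, has_title=False))  # 0.6
```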
@@ -4,6 +4,7 @@ from abc import abstractmethod
from sqlalchemy.orm import Session

from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -14,6 +15,7 @@ from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
@@ -69,6 +71,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
    def embed_chunks(
        self,
        chunks: list[DocAwareChunk],
        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
        enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
    ) -> list[IndexChunk]:
        # Cache the Title embeddings to only have to do it once
@@ -77,7 +80,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):

        # Create Mini Chunks for more precise matching of details
        # Off by default with unedited settings
        chunk_texts: list[str] = []
        chunk_texts = []
        chunk_mini_chunks_count = {}
        for chunk_ind, chunk in enumerate(chunks):
            chunk_texts.append(chunk.content)
@@ -89,9 +92,22 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            chunk_texts.extend(mini_chunk_texts)
            chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)

        embeddings = self.embedding_model.encode(
            chunk_texts, text_type=EmbedTextType.PASSAGE
        )
        # Batching for embedding
        text_batches = batch_list(chunk_texts, batch_size)

        embeddings: list[list[float]] = []
        len_text_batches = len(text_batches)
        for idx, text_batch in enumerate(text_batches, start=1):
            logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
            # Normalize embeddings is only configured via model_configs.py, be sure to use right
            # value for the set loss
            embeddings.extend(
                self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
            )

            # Replace line above with the line below for easy debugging of indexing flow
            # skipping the actual model
            # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])

        chunk_titles = {
            chunk.source_document.get_title_for_document_index() for chunk in chunks
@@ -100,15 +116,16 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
        # Drop any None or empty strings
        chunk_titles_list = [title for title in chunk_titles if title]

        if chunk_titles_list:
            # Embed Titles in batches
            title_batches = batch_list(chunk_titles_list, batch_size)
            len_title_batches = len(title_batches)
            for ind_batch, title_batch in enumerate(title_batches, start=1):
                logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
                title_embeddings = self.embedding_model.encode(
                    chunk_titles_list, text_type=EmbedTextType.PASSAGE
                    title_batch, text_type=EmbedTextType.PASSAGE
                )
                title_embed_dict.update(
                    {
                        title: vector
                        for title, vector in zip(chunk_titles_list, title_embeddings)
                    }
                    {title: vector for title, vector in zip(title_batch, title_embeddings)}
                )

        # Mapping embeddings to chunks
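Illustrative sketch (not part of the diff) of the batching pattern this hunk introduces, assuming `batch_list` simply splits a list into fixed-size sublists:

```python
from typing import TypeVar

T = TypeVar("T")


def batch_list(lst: list[T], batch_size: int) -> list[list[T]]:
    # Assumed behavior of danswer.utils.batching.batch_list: fixed-size sublists.
    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]


def embed_in_batches(chunk_texts: list[str], batch_size: int = 8) -> list[list[float]]:
    embeddings: list[list[float]] = []
    text_batches = batch_list(chunk_texts, batch_size)
    for idx, text_batch in enumerate(text_batches, start=1):
        print(f"Embedding Content Texts batch {idx} of {len(text_batches)}")
        # Stand-in for self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE),
        # mirroring the "skip the actual model" debugging shortcut mentioned in the hunk above.
        embeddings.extend([[0.0] * 384 for _ in text_batch])
    return embeddings


print(len(embed_in_batches([f"chunk {i}" for i in range(20)])))  # 20 vectors across 3 batches
```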
@@ -34,7 +34,6 @@ from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time

logger = setup_logger()

@@ -155,7 +154,6 @@ class SearchPipeline:

        return cast(list[InferenceChunk], self._retrieved_chunks)

    @log_function_time(print_only=True)
    def _get_sections(self) -> list[InferenceSection]:
        """Returns an expanded section from each of the chunks.
        If whole docs (instead of above/below context) is specified then it will give back all of the whole docs
@@ -175,11 +173,9 @@ class SearchPipeline:
        expanded_inference_sections = []

        # Full doc setting takes priority

        if self.search_query.full_doc:
            seen_document_ids = set()
            unique_chunks = []

            # This preserves the ordering since the chunks are retrieved in score order
            for chunk in retrieved_chunks:
                if chunk.document_id not in seen_document_ids:
@@ -199,6 +195,7 @@ class SearchPipeline:
                        ),
                    )
                )

            list_inference_chunks = run_functions_tuples_in_parallel(
                functions_with_args, allow_failures=False
            )
@@ -243,35 +240,32 @@ class SearchPipeline:
        merged_ranges = [
            merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
        ]

        flat_ranges: list[ChunkRange] = [r for ranges in merged_ranges for r in ranges]
        flattened_inference_chunks: list[InferenceChunk] = []
        parallel_functions_with_args = []
        flat_ranges = [r for ranges in merged_ranges for r in ranges]

        for chunk_range in flat_ranges:
            # Don't need to fetch chunks within range for merging if chunk_above / below are 0.
            if above == below == 0:
                flattened_inference_chunks.extend(chunk_range.chunks)

            else:
                parallel_functions_with_args.append(
            functions_with_args.append(
                (
                    # If Large Chunks are introduced, additional filters need to be added here
                    self.document_index.id_based_retrieval,
                    (
                        self.document_index.id_based_retrieval,
                        (
                            chunk_range.chunks[0].document_id,
                            chunk_range.start,
                            chunk_range.end,
                            IndexFilters(access_control_list=None),
                        ),
                    )
                    # Only need the document_id here, just use any chunk in the range is fine
                    chunk_range.chunks[0].document_id,
                    chunk_range.start,
                    chunk_range.end,
                    # There is no chunk level permissioning, this expansion around chunks
                    # can be assumed to be safe
                    IndexFilters(access_control_list=None),
                ),
            )

        if parallel_functions_with_args:
            list_inference_chunks = run_functions_tuples_in_parallel(
                parallel_functions_with_args, allow_failures=False
            )
            for inference_chunks in list_inference_chunks:
                flattened_inference_chunks.extend(inference_chunks)

        # list of list of inference chunks where the inner list needs to be combined for content
        list_inference_chunks = run_functions_tuples_in_parallel(
            functions_with_args, allow_failures=False
        )
        flattened_inference_chunks = [
            chunk for sublist in list_inference_chunks for chunk in sublist
        ]

        doc_chunk_ind_to_chunk = {
            (chunk.document_id, chunk.chunk_id): chunk
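For context (not part of the diff): the pipeline collects `(function, args)` tuples and hands them to `run_functions_tuples_in_parallel`. A rough stand-in using the standard library, shown only to illustrate the calling pattern; the real helper (ordering guarantees, `allow_failures` handling) may behave differently.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable


def run_functions_tuples_in_parallel_sketch(
    functions_with_args: list[tuple[Callable[..., Any], tuple[Any, ...]]],
) -> list[Any]:
    # Results come back in submission order, like the list the pipeline flattens.
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(fn, *args) for fn, args in functions_with_args]
        return [future.result() for future in futures]


def fake_id_based_retrieval(document_id: str, start: int, end: int) -> list[str]:
    # Hypothetical stand-in for document_index.id_based_retrieval.
    return [f"{document_id}:chunk{i}" for i in range(start, end + 1)]


functions_with_args = [
    (fake_id_based_retrieval, ("doc-1", 0, 2)),
    (fake_id_based_retrieval, ("doc-2", 5, 6)),
]
list_inference_chunks = run_functions_tuples_in_parallel_sketch(functions_with_args)
flattened = [chunk for sublist in list_inference_chunks for chunk in sublist]
print(flattened)
```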
@@ -5,13 +5,10 @@ from typing import Optional
from typing import TYPE_CHECKING

import requests
from httpx import HTTPError
from transformers import logging as transformer_logging  # type:ignore

from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
@@ -106,73 +103,28 @@ class EmbeddingModel:
        model_server_url = build_model_server_url(server_host, server_port)
        self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

    def encode(
        self,
        texts: list[str],
        text_type: EmbedTextType,
        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
    ) -> list[list[float]]:
        if not texts:
            logger.warning("No texts to be embedded")
            return []
    def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
        if text_type == EmbedTextType.QUERY and self.query_prefix:
            prefixed_texts = [self.query_prefix + text for text in texts]
        elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
            prefixed_texts = [self.passage_prefix + text for text in texts]
        else:
            prefixed_texts = texts

        if self.provider_type:
            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=texts,
                max_context_length=self.max_seq_length,
                normalize_embeddings=self.normalize,
                api_key=self.api_key,
                provider_type=self.provider_type,
                text_type=text_type,
                manual_query_prefix=self.query_prefix,
                manual_passage_prefix=self.passage_prefix,
            )
            response = requests.post(
                self.embed_server_endpoint, json=embed_request.dict()
            )
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                error_detail = response.json().get("detail", str(e))
                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
            except requests.RequestException as e:
                raise HTTPError(f"Request failed: {str(e)}") from e
            return EmbedResponse(**response.json()).embeddings
        embed_request = EmbedRequest(
            model_name=self.model_name,
            texts=prefixed_texts,
            max_context_length=self.max_seq_length,
            normalize_embeddings=self.normalize,
            api_key=self.api_key,
            provider_type=self.provider_type,
            text_type=text_type,
        )

        # Batching for local embedding
        text_batches = batch_list(texts, batch_size)
        embeddings: list[list[float]] = []
        for idx, text_batch in enumerate(text_batches, start=1):
            logger.debug(f"Embedding Content Texts batch {idx} of {len(text_batches)}")
        response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
        response.raise_for_status()

            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=text_batch,
                max_context_length=self.max_seq_length,
                normalize_embeddings=self.normalize,
                api_key=self.api_key,
                provider_type=self.provider_type,
                text_type=text_type,
                manual_query_prefix=self.query_prefix,
                manual_passage_prefix=self.passage_prefix,
            )

            response = requests.post(
                self.embed_server_endpoint, json=embed_request.dict()
            )
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                error_detail = response.json().get("detail", str(e))
                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
            except requests.RequestException as e:
                raise HTTPError(f"Request failed: {str(e)}") from e
            # Normalize embeddings is only configured via model_configs.py, be sure to use right
            # value for the set loss
            embeddings.extend(EmbedResponse(**response.json()).embeddings)

        return embeddings
        return EmbedResponse(**response.json()).embeddings


class CrossEncoderEnsembleModel:
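Illustrative sketch (not part of the diff) of the prefixing step in `EmbeddingModel.encode` above: query and passage texts can get different manual prefixes before they are sent to the model server. The enum and prefix strings below are stand-ins, not the real `shared_configs` values.

```python
from enum import Enum


class EmbedTextType(str, Enum):
    # Stand-in for shared_configs.enums.EmbedTextType; member values are illustrative.
    QUERY = "query"
    PASSAGE = "passage"


def apply_prefix(
    texts: list[str],
    text_type: EmbedTextType,
    query_prefix: str | None = "query: ",
    passage_prefix: str | None = "passage: ",
) -> list[str]:
    # Same branching as the encode() hunk above.
    if text_type == EmbedTextType.QUERY and query_prefix:
        return [query_prefix + text for text in texts]
    if text_type == EmbedTextType.PASSAGE and passage_prefix:
        return [passage_prefix + text for text in texts]
    return texts


print(apply_prefix(["what is hybrid search?"], EmbedTextType.QUERY))
# ['query: what is hybrid search?']
```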
@@ -42,7 +42,7 @@ def test_embedding_configuration(
            passage_prefix=None,
            model_name=None,
        )
        test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)
        test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)

    except ValueError as e:
        error_msg = f"Not a valid embedding model. Exception thrown: {e}"
@@ -10,7 +10,6 @@ from cohere import Client as CohereClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account  # type: ignore
from retry import retry
from sentence_transformers import CrossEncoder  # type: ignore
from sentence_transformers import SentenceTransformer  # type: ignore
from vertexai.language_models import TextEmbeddingInput  # type: ignore
@@ -41,131 +40,110 @@ router = APIRouter(prefix="/encoder")

_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
# If we are not only indexing, dont want retry very long
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2


def _initialize_client(
    api_key: str, provider: EmbeddingProvider, model: str | None = None
) -> Any:
    if provider == EmbeddingProvider.OPENAI:
        return openai.OpenAI(api_key=api_key)
    elif provider == EmbeddingProvider.COHERE:
        return CohereClient(api_key=api_key)
    elif provider == EmbeddingProvider.VOYAGE:
        return voyageai.Client(api_key=api_key)
    elif provider == EmbeddingProvider.GOOGLE:
        credentials = service_account.Credentials.from_service_account_info(
            json.loads(api_key)
        )
        project_id = json.loads(api_key)["project_id"]
        vertexai.init(project=project_id, credentials=credentials)
        return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
    else:
        raise ValueError(f"Unsupported provider: {provider}")


class CloudEmbedding:
    def __init__(
        self,
        api_key: str,
        provider: str,
    def __init__(self, api_key: str, provider: str, model: str | None = None):
        self.api_key = api_key

        # Only for Google as is needed on client setup
        model: str | None = None,
    ) -> None:
        self.model = model
        try:
            self.provider = EmbeddingProvider(provider.lower())
        except ValueError:
            raise ValueError(f"Unsupported provider: {provider}")
        self.client = _initialize_client(api_key, self.provider, model)
        self.client = self._initialize_client()

    def _embed_openai(self, texts: list[str], model: str | None) -> list[list[float]]:
    def _initialize_client(self) -> Any:
        if self.provider == EmbeddingProvider.OPENAI:
            return openai.OpenAI(api_key=self.api_key)
        elif self.provider == EmbeddingProvider.COHERE:
            return CohereClient(api_key=self.api_key)
        elif self.provider == EmbeddingProvider.VOYAGE:
            return voyageai.Client(api_key=self.api_key)
        elif self.provider == EmbeddingProvider.GOOGLE:
            credentials = service_account.Credentials.from_service_account_info(
                json.loads(self.api_key)
            )
            project_id = json.loads(self.api_key)["project_id"]
            vertexai.init(project=project_id, credentials=credentials)
            return TextEmbeddingModel.from_pretrained(
                self.model or DEFAULT_VERTEX_MODEL
            )

        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def encode(
        self, texts: list[str], model_name: str | None, text_type: EmbedTextType
    ) -> list[list[float]]:
        return [
            self.embed(text=text, text_type=text_type, model=model_name)
            for text in texts
        ]

    def embed(
        self, *, text: str, text_type: EmbedTextType, model: str | None = None
    ) -> list[float]:
        logger.debug(f"Embedding text with provider: {self.provider}")
        if self.provider == EmbeddingProvider.OPENAI:
            return self._embed_openai(text, model)

        embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)

        if self.provider == EmbeddingProvider.COHERE:
            return self._embed_cohere(text, model, embedding_type)
        elif self.provider == EmbeddingProvider.VOYAGE:
            return self._embed_voyage(text, model, embedding_type)
        elif self.provider == EmbeddingProvider.GOOGLE:
            return self._embed_vertex(text, model, embedding_type)
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def _embed_openai(self, text: str, model: str | None) -> list[float]:
        if model is None:
            model = DEFAULT_OPENAI_MODEL

        # OpenAI does not seem to provide truncation option, however
        # the context lengths used by Danswer currently are smaller than the max token length
        # for OpenAI embeddings so it's not a big deal
        response = self.client.embeddings.create(input=texts, model=model)
        return [embedding.embedding for embedding in response.data]
        response = self.client.embeddings.create(input=text, model=model)
        return response.data[0].embedding

    def _embed_cohere(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[list[float]]:
        self, text: str, model: str | None, embedding_type: str
    ) -> list[float]:
        if model is None:
            model = DEFAULT_COHERE_MODEL

        # Does not use the same tokenizer as the Danswer API server but it's approximately the same
        # empirically it's only off by a very few tokens so it's not a big deal
        response = self.client.embed(
            texts=texts,
            texts=[text],
            model=model,
            input_type=embedding_type,
            truncate="END",
        )
        return response.embeddings
        return response.embeddings[0]

    def _embed_voyage(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[list[float]]:
        self, text: str, model: str | None, embedding_type: str
    ) -> list[float]:
        if model is None:
            model = DEFAULT_VOYAGE_MODEL

        # Similar to Cohere, the API server will do approximate size chunking
        # it's acceptable to miss by a few tokens
        response = self.client.embed(
            texts,
            model=model,
            input_type=embedding_type,
            truncation=True,  # Also this is default
        )
        return response.embeddings
        response = self.client.embed(text, model=model, input_type=embedding_type)
        return response.embeddings[0]

    def _embed_vertex(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[list[float]]:
        self, text: str, model: str | None, embedding_type: str
    ) -> list[float]:
        if model is None:
            model = DEFAULT_VERTEX_MODEL

        embeddings = self.client.get_embeddings(
        embedding = self.client.get_embeddings(
            [
                TextEmbeddingInput(
                    text,
                    embedding_type,
                )
                for text in texts
            ],
            auto_truncate=True,  # Also this is default
            ]
        )
        return [embedding.values for embedding in embeddings]

    @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
    def embed(
        self,
        *,
        texts: list[str],
        text_type: EmbedTextType,
        model_name: str | None = None,
    ) -> list[list[float]]:
        try:
            if self.provider == EmbeddingProvider.OPENAI:
                return self._embed_openai(texts, model_name)

            embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
            if self.provider == EmbeddingProvider.COHERE:
                return self._embed_cohere(texts, model_name, embedding_type)
            elif self.provider == EmbeddingProvider.VOYAGE:
                return self._embed_voyage(texts, model_name, embedding_type)
            elif self.provider == EmbeddingProvider.GOOGLE:
                return self._embed_vertex(texts, model_name, embedding_type)
            else:
                raise ValueError(f"Unsupported provider: {self.provider}")
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error embedding text with {self.provider}: {str(e)}",
            )
        return embedding[0].values

    @staticmethod
    def create(
@@ -234,52 +212,29 @@ def embed_text(
    normalize_embeddings: bool,
    api_key: str | None,
    provider_type: str | None,
    prefix: str | None,
) -> list[list[float]]:
    # Third party API based embedding model
    if provider_type is not None:
        logger.debug(f"Embedding text with provider: {provider_type}")
        if api_key is None:
            raise RuntimeError("API key not provided for cloud model")

        if prefix:
            # This may change in the future if some providers require the user
            # to manually append a prefix but this is not the case currently
            raise ValueError(
                "Prefix string is not valid for cloud models. "
                "Cloud models take an explicit text type instead."
            )

        cloud_model = CloudEmbedding(
            api_key=api_key, provider=provider_type, model=model_name
        )
        embeddings = cloud_model.embed(
            texts=texts,
            model_name=model_name,
            text_type=text_type,
        )
        embeddings = cloud_model.encode(texts, model_name, text_type)

    # Locally running model
    elif model_name is not None:
        prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
        local_model = get_embedding_model(
        hosted_model = get_embedding_model(
            model_name=model_name, max_context_length=max_context_length
        )
        embeddings = local_model.encode(
            prefixed_texts, normalize_embeddings=normalize_embeddings
        )

    else:
        raise ValueError(
            "Either model name or provider must be provided to run embeddings."
        embeddings = hosted_model.encode(
            texts, normalize_embeddings=normalize_embeddings
        )

    if embeddings is None:
        raise RuntimeError("Failed to create Embeddings")
        raise RuntimeError("Embeddings were not created")

    if not isinstance(embeddings, list):
        embeddings = embeddings.tolist()

    return embeddings


@@ -297,17 +252,7 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
async def process_embed_request(
    embed_request: EmbedRequest,
) -> EmbedResponse:
    if not embed_request.texts:
        raise HTTPException(status_code=400, detail="No texts to be embedded")

    try:
        if embed_request.text_type == EmbedTextType.QUERY:
            prefix = embed_request.manual_query_prefix
        elif embed_request.text_type == EmbedTextType.PASSAGE:
            prefix = embed_request.manual_passage_prefix
        else:
            prefix = None

        embeddings = embed_text(
            texts=embed_request.texts,
            model_name=embed_request.model_name,
@@ -316,13 +261,13 @@ async def process_embed_request(
            api_key=embed_request.api_key,
            provider_type=embed_request.provider_type,
            text_type=embed_request.text_type,
            prefix=prefix,
        )
        return EmbedResponse(embeddings=embeddings)
    except Exception as e:
        exception_detail = f"Error during embedding process:\n{str(e)}"
        logger.exception(exception_detail)
        raise HTTPException(status_code=500, detail=exception_detail)
        logger.exception(f"Error during embedding process:\n{str(e)}")
        raise HTTPException(
            status_code=500, detail="Failed to run Bi-Encoder embedding"
        )


@router.post("/cross-encoder-scores")
@@ -331,11 +276,6 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse
    if INDEXING_ONLY:
        raise RuntimeError("Indexing model server should not call intent endpoint")

    if not embed_request.documents or not embed_request.query:
        raise HTTPException(
            status_code=400, detail="No documents or query to be reranked"
        )

    try:
        sim_scores = calc_sim_scores(
            query=embed_request.query, docs=embed_request.documents
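A client-side sketch (not part of the diff) of calling the bi-encoder endpoint that `process_embed_request` serves. The host/port, model name, and serialized `text_type` value are placeholders; the field names follow the `EmbedRequest`/`EmbedResponse` models referenced above.

```python
import requests

# Hypothetical local model server address; the endpoint path matches
# EmbeddingModel.embed_server_endpoint in the earlier hunk.
MODEL_SERVER_URL = "http://localhost:9000"

payload = {
    "texts": ["passage: hybrid search blends keyword and vector scores"],
    "model_name": "intfloat/e5-base-v2",  # placeholder local model
    "max_context_length": 512,
    "normalize_embeddings": True,
    "api_key": None,
    "provider_type": None,
    "text_type": "passage",  # assumed serialized value of EmbedTextType.PASSAGE
    "manual_query_prefix": "query: ",
    "manual_passage_prefix": "passage: ",
}

response = requests.post(f"{MODEL_SERVER_URL}/encoder/bi-encoder-embed", json=payload)
response.raise_for_status()
embeddings = response.json()["embeddings"]
print(f"{len(embeddings)} embedding vector(s) returned")
```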
@@ -1,7 +1,6 @@
fastapi==0.109.2
h5py==3.9.0
pydantic==1.10.13
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
tensorflow==2.15.0
@@ -10,5 +9,5 @@ transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
openai==1.14.3
cohere==5.6.1
google-cloud-aiplatform==1.58.0
cohere==5.5.8
google-cloud-aiplatform==1.58.0
@@ -4,7 +4,9 @@ from shared_configs.enums import EmbedTextType


class EmbedRequest(BaseModel):
    # This already includes any prefixes, the text is just passed directly to the model
    texts: list[str]

    # Can be none for cloud embedding model requests, error handling logic exists for other cases
    model_name: str | None
    max_context_length: int
@@ -12,8 +14,6 @@ class EmbedRequest(BaseModel):
    api_key: str | None
    provider_type: str | None
    text_type: EmbedTextType
    manual_query_prefix: str | None
    manual_passage_prefix: str | None


class EmbedResponse(BaseModel):
@@ -50,7 +50,7 @@ export default async function GalleryPage({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availableAssistants: assistants,
availablePersonas: assistants,
availableTags: tags,
llmProviders,
folders,

@@ -52,7 +52,7 @@ export default async function GalleryPage({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availableAssistants: assistants,
availablePersonas: assistants,
availableTags: tags,
llmProviders,
folders,
@@ -63,6 +63,7 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
import { ConfigurationModal } from "./modal/configuration/ConfigurationModal";
import { useChatContext } from "@/components/context/ChatContext";
import { v4 as uuidv4 } from "uuid";
import { orderAssistantsForUser } from "@/lib/assistants/orderAssistants";
@@ -81,29 +82,38 @@ const SYSTEM_MESSAGE_ID = -3;
export function ChatPage({
toggle,
documentSidebarInitialWidth,
defaultSelectedAssistantId,
defaultSelectedPersonaId,
toggledSidebar,
}: {
toggle: () => void;
documentSidebarInitialWidth?: number;
defaultSelectedAssistantId?: number;
defaultSelectedPersonaId?: number;
toggledSidebar: boolean;
}) {
const router = useRouter();
const searchParams = useSearchParams();

const [configModalActiveTab, setConfigModalActiveTab] = useState<
string | null
>(null);
let {
user,
chatSessions,
availableSources,
availableDocumentSets,
availableAssistants,
availablePersonas,
llmProviders,
folders,
openedFolders,
} = useChatContext();

// chat session
const filteredAssistants = orderAssistantsForUser(availablePersonas, user);

const [selectedAssistant, setSelectedAssistant] = useState<Persona | null>(
null
);
const [alternativeGeneratingAssistant, setAlternativeGeneratingAssistant] =
useState<Persona | null>(null);

const router = useRouter();
const searchParams = useSearchParams();
const existingChatIdRaw = searchParams.get("chatId");
const existingChatSessionId = existingChatIdRaw
? parseInt(existingChatIdRaw)
@@ -113,51 +123,9 @@ export function ChatPage({
);
const chatSessionIdRef = useRef<number | null>(existingChatSessionId);

// LLM
const llmOverrideManager = useLlmOverride(selectedChatSession);

// Assistants
const filteredAssistants = orderAssistantsForUser(availableAssistants, user);

const existingChatSessionAssistantId = selectedChatSession?.persona_id;
const [selectedAssistant, setSelectedAssistant] = useState<
Persona | undefined
>(
// NOTE: look through available assistants here, so that even if the user
// has hidden this assistant it still shows the correct assistant when
// going back to an old chat session
existingChatSessionAssistantId !== undefined
? availableAssistants.find(
(assistant) => assistant.id === existingChatSessionAssistantId
)
: defaultSelectedAssistantId !== undefined
? availableAssistants.find(
(assistant) => assistant.id === defaultSelectedAssistantId
)
: undefined
);
const setSelectedAssistantFromId = (assistantId: number) => {
// NOTE: also intentionally look through available assistants here, so that
// even if the user has hidden an assistant they can still go back to it
// for old chats
setSelectedAssistant(
availableAssistants.find((assistant) => assistant.id === assistantId)
);
};
const liveAssistant =
selectedAssistant || filteredAssistants[0] || availableAssistants[0];

// this is for "@"ing assistants
const [alternativeAssistant, setAlternativeAssistant] =
useState<Persona | null>(null);

// this is used to track which assistant is being used to generate the current message
// for example, this would come into play when:
// 1. default assistant is `Danswer`
// 2. we "@"ed the `GPT` assistant and sent a message
// 3. while the `GPT` assistant message is generating, we "@" the `Paraphrase` assistant
const [alternativeGeneratingAssistant, setAlternativeGeneratingAssistant] =
useState<Persona | null>(null);
const existingChatSessionPersonaId = selectedChatSession?.persona_id;

// used to track whether or not the initial "submit on load" has been performed
// this only applies if `?submit-on-load=true` or `?submit-on-load=1` is in the URL
@@ -214,10 +182,14 @@ export function ChatPage({
async function initialSessionFetch() {
if (existingChatSessionId === null) {
setIsFetchingChatMessages(false);
if (defaultSelectedAssistantId !== undefined) {
setSelectedAssistantFromId(defaultSelectedAssistantId);
if (defaultSelectedPersonaId !== undefined) {
setSelectedPersona(
filteredAssistants.find(
(persona) => persona.id === defaultSelectedPersonaId
)
);
} else {
setSelectedAssistant(undefined);
setSelectedPersona(undefined);
}
setCompleteMessageDetail({
sessionId: null,
@@ -242,7 +214,12 @@ export function ChatPage({
);

const chatSession = (await response.json()) as BackendChatSession;
setSelectedAssistantFromId(chatSession.persona_id);

setSelectedPersona(
filteredAssistants.find(
(persona) => persona.id === chatSession.persona_id
)
);

const newMessageMap = processRawChatHistory(chatSession.messages);
const newMessageHistory = buildLatestMessageChain(newMessageMap);
@@ -396,18 +373,32 @@ export function ChatPage({
)
: { aiMessage: null };

const [selectedPersona, setSelectedPersona] = useState<Persona | undefined>(
existingChatSessionPersonaId !== undefined
? filteredAssistants.find(
(persona) => persona.id === existingChatSessionPersonaId
)
: defaultSelectedPersonaId !== undefined
? filteredAssistants.find(
(persona) => persona.id === defaultSelectedPersonaId
)
: undefined
);
const livePersona =
selectedPersona || filteredAssistants[0] || availablePersonas[0];

const [chatSessionSharedStatus, setChatSessionSharedStatus] =
useState<ChatSessionSharedStatus>(ChatSessionSharedStatus.Private);

useEffect(() => {
if (messageHistory.length === 0 && chatSessionIdRef.current === null) {
setSelectedAssistant(
setSelectedPersona(
filteredAssistants.find(
(persona) => persona.id === defaultSelectedAssistantId
(persona) => persona.id === defaultSelectedPersonaId
)
);
}
}, [defaultSelectedAssistantId]);
}, [defaultSelectedPersonaId]);

const [
selectedDocuments,
@@ -423,7 +414,7 @@ export function ChatPage({
useEffect(() => {
async function fetchMaxTokens() {
const response = await fetch(
`/api/chat/max-selected-document-tokens?persona_id=${liveAssistant.id}`
`/api/chat/max-selected-document-tokens?persona_id=${livePersona.id}`
);
if (response.ok) {
const maxTokens = (await response.json()).max_tokens as number;
@@ -432,12 +423,12 @@ export function ChatPage({
}

fetchMaxTokens();
}, [liveAssistant]);
}, [livePersona]);

const filterManager = useFilters();
const [finalAvailableSources, finalAvailableDocumentSets] =
computeAvailableFilters({
selectedPersona: selectedAssistant,
selectedPersona,
availableSources,
availableDocumentSets,
});
@@ -633,16 +624,16 @@ export function ChatPage({
queryOverride,
forceSearch,
isSeededChat,
alternativeAssistantOverride = null,
alternativeAssistant = null,
}: {
messageIdToResend?: number;
messageOverride?: string;
queryOverride?: string;
forceSearch?: boolean;
isSeededChat?: boolean;
alternativeAssistantOverride?: Persona | null;
alternativeAssistant?: Persona | null;
} = {}) => {
setAlternativeGeneratingAssistant(alternativeAssistantOverride);
setAlternativeGeneratingAssistant(alternativeAssistant);

clientScrollToBottom();
let currChatSessionId: number;
@@ -652,7 +643,7 @@ export function ChatPage({

if (isNewSession) {
currChatSessionId = await createChatSession(
liveAssistant?.id || 0,
livePersona?.id || 0,
searchParamBasedChatSessionName
);
} else {
@@ -730,9 +721,9 @@ export function ChatPage({
parentMessage = frozenMessageMap.get(SYSTEM_MESSAGE_ID) || null;
}

const currentAssistantId = alternativeAssistantOverride
? alternativeAssistantOverride.id
: alternativeAssistant?.id || liveAssistant.id;
const currentAssistantId = alternativeAssistant
? alternativeAssistant.id
: selectedAssistant?.id;

resetInputBar();

@@ -760,7 +751,7 @@ export function ChatPage({
fileDescriptors: currentMessageFiles,
parentMessageId: lastSuccessfulMessageId,
chatSessionId: currChatSessionId,
promptId: liveAssistant?.prompts[0]?.id || 0,
promptId: livePersona?.prompts[0]?.id || 0,
filters: buildFilters(
filterManager.selectedSources,
filterManager.selectedDocumentSets,
@@ -877,7 +868,7 @@ export function ChatPage({
files: finalMessage?.files || aiMessageImages || [],
toolCalls: finalMessage?.tool_calls || toolCalls,
parentMessageId: newUserMessageId,
alternateAssistantID: alternativeAssistant?.id,
alternateAssistantID: selectedAssistant?.id,
},
]);
}
@@ -973,23 +964,19 @@ export function ChatPage({
}
};

const onAssistantChange = (assistant: Persona | null) => {
if (assistant && assistant.id !== liveAssistant.id) {
const onPersonaChange = (persona: Persona | null) => {
if (persona && persona.id !== livePersona.id) {
// remove uploaded files
setCurrentMessageFiles([]);
setSelectedAssistant(assistant);
setSelectedPersona(persona);
textAreaRef.current?.focus();
router.push(buildChatUrl(searchParams, null, assistant.id));
router.push(buildChatUrl(searchParams, null, persona.id));
}
};

const handleImageUpload = (acceptedFiles: File[]) => {
const llmAcceptsImages = checkLLMSupportsImageInput(
...getFinalLLM(
llmProviders,
liveAssistant,
llmOverrideManager.llmOverride
)
...getFinalLLM(llmProviders, livePersona, llmOverrideManager.llmOverride)
);
const imageFiles = acceptedFiles.filter((file) =>
file.type.startsWith("image/")
@@ -1071,23 +1058,23 @@ export function ChatPage({
useEffect(() => {
const includes = checkAnyAssistantHasSearch(
messageHistory,
availableAssistants,
liveAssistant
availablePersonas,
livePersona
);
setRetrievalEnabled(includes);
}, [messageHistory, availableAssistants, liveAssistant]);
}, [messageHistory, availablePersonas, livePersona]);

const [retrievalEnabled, setRetrievalEnabled] = useState(() => {
return checkAnyAssistantHasSearch(
messageHistory,
availableAssistants,
liveAssistant
availablePersonas,
livePersona
);
});

const innerSidebarElementRef = useRef<HTMLDivElement>(null);

const currentPersona = alternativeAssistant || liveAssistant;
const currentPersona = selectedAssistant || livePersona;

useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
@@ -1189,8 +1176,21 @@ export function ChatPage({
/>
)}

<ConfigurationModal
chatSessionId={chatSessionIdRef.current!}
activeTab={configModalActiveTab}
setActiveTab={setConfigModalActiveTab}
onClose={() => setConfigModalActiveTab(null)}
filterManager={filterManager}
availableAssistants={filteredAssistants}
selectedAssistant={livePersona}
setSelectedAssistant={onPersonaChange}
llmProviders={llmProviders}
llmOverrideManager={llmOverrideManager}
/>

<div className="flex h-[calc(100dvh)] flex-col w-full">
{liveAssistant && (
{livePersona && (
<FunctionalHeader
page="chat"
setSharingModalVisible={
@@ -1239,7 +1239,7 @@ export function ChatPage({
!isStreaming && (
<ChatIntro
availableSources={finalAvailableSources}
selectedPersona={liveAssistant}
selectedPersona={livePersona}
/>
)}
<div
@@ -1319,7 +1319,7 @@ export function ChatPage({

const currentAlternativeAssistant =
message.alternateAssistantID != null
? availableAssistants.find(
? availablePersonas.find(
(persona) =>
persona.id ==
message.alternateAssistantID
@@ -1342,7 +1342,7 @@ export function ChatPage({
toggleDocumentSelectionAspects
}
docs={message.documents}
currentPersona={liveAssistant}
currentPersona={livePersona}
alternativeAssistant={
currentAlternativeAssistant
}
@@ -1352,7 +1352,7 @@ export function ChatPage({
query={
messageHistory[i]?.query || undefined
}
personaName={liveAssistant.name}
personaName={livePersona.name}
citedDocuments={getCitedDocumentsFromMessage(
message
)}
@@ -1404,7 +1404,7 @@ export function ChatPage({
messageIdToResend:
previousMessage.messageId,
queryOverride: newQuery,
alternativeAssistantOverride:
alternativeAssistant:
currentAlternativeAssistant,
});
}
@@ -1435,7 +1435,7 @@ export function ChatPage({
messageIdToResend:
previousMessage.messageId,
forceSearch: true,
alternativeAssistantOverride:
alternativeAssistant:
currentAlternativeAssistant,
});
} else {
@@ -1460,9 +1460,9 @@ export function ChatPage({
return (
<div key={messageReactComponentKey}>
<AIMessage
currentPersona={liveAssistant}
currentPersona={livePersona}
messageId={message.messageId}
personaName={liveAssistant.name}
personaName={livePersona.name}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
@@ -1481,13 +1481,13 @@ export function ChatPage({
key={`${messageHistory.length}-${chatSessionIdRef.current}`}
>
<AIMessage
currentPersona={liveAssistant}
currentPersona={livePersona}
alternativeAssistant={
alternativeGeneratingAssistant ??
alternativeAssistant
selectedAssistant
}
messageId={null}
personaName={liveAssistant.name}
personaName={livePersona.name}
content={
<div className="text-sm my-auto">
<ThreeDots
@@ -1513,7 +1513,7 @@ export function ChatPage({
{currentPersona &&
currentPersona.starter_messages &&
currentPersona.starter_messages.length > 0 &&
selectedAssistant &&
selectedPersona &&
messageHistory.length === 0 &&
!isFetchingChatMessages && (
<div
@@ -1570,25 +1570,32 @@ export function ChatPage({
<ChatInputBar
showDocs={() => setDocumentSelection(true)}
selectedDocuments={selectedDocuments}
// assistant stuff
assistantOptions={filteredAssistants}
selectedAssistant={liveAssistant}
setSelectedAssistant={onAssistantChange}
setAlternativeAssistant={setAlternativeAssistant}
alternativeAssistant={alternativeAssistant}
// end assistant stuff
setSelectedAssistant={onPersonaChange}
onSetSelectedAssistant={(
alternativeAssistant: Persona | null
) => {
setSelectedAssistant(alternativeAssistant);
}}
alternativeAssistant={selectedAssistant}
personas={filteredAssistants}
message={message}
setMessage={setMessage}
onSubmit={onSubmit}
isStreaming={isStreaming}
setIsCancelled={setIsCancelled}
retrievalDisabled={
!personaIncludesRetrieval(currentPersona)
}
filterManager={filterManager}
llmOverrideManager={llmOverrideManager}
selectedAssistant={livePersona}
files={currentMessageFiles}
setFiles={setCurrentMessageFiles}
handleFileUpload={handleImageUpload}
setConfigModalActiveTab={setConfigModalActiveTab}
textAreaRef={textAreaRef}
chatSessionId={chatSessionIdRef.current!}
availableAssistants={availablePersonas}
/>
</div>
</div>
@@ -5,10 +5,10 @@ import { ChatPage } from "./ChatPage";
import FunctionalWrapper from "./shared_chat_search/FunctionalWrapper";

export default function WrappedChat({
defaultAssistantId,
defaultPersonaId,
initiallyToggled,
}: {
defaultAssistantId?: number;
defaultPersonaId?: number;
initiallyToggled: boolean;
}) {
return (
@@ -17,7 +17,7 @@ export default function WrappedChat({
content={(toggledSidebar, toggle) => (
<ChatPage
toggle={toggle}
defaultSelectedAssistantId={defaultAssistantId}
defaultSelectedPersonaId={defaultPersonaId}
toggledSidebar={toggledSidebar}
/>
)}
@@ -21,6 +21,7 @@ import { IconType } from "react-icons";
import Popup from "../../../components/popup/Popup";
import { LlmTab } from "../modal/configuration/LlmTab";
import { AssistantsTab } from "../modal/configuration/AssistantsTab";
import ChatInputAssistant from "./ChatInputAssistant";
import { DanswerDocument } from "@/lib/search/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Tooltip } from "@/components/tooltip/Tooltip";
@@ -28,6 +29,7 @@ import { Hoverable } from "@/components/Hoverable";
const MAX_INPUT_HEIGHT = 200;

export function ChatInputBar({
personas,
showDocs,
selectedDocuments,
message,
@@ -35,32 +37,34 @@ export function ChatInputBar({
onSubmit,
isStreaming,
setIsCancelled,
retrievalDisabled,
filterManager,
llmOverrideManager,

// assistants
onSetSelectedAssistant,
selectedAssistant,
assistantOptions,
setSelectedAssistant,
setAlternativeAssistant,

files,

setSelectedAssistant,
setFiles,
handleFileUpload,
setConfigModalActiveTab,
textAreaRef,
alternativeAssistant,
chatSessionId,
availableAssistants,
}: {
showDocs: () => void;
selectedDocuments: DanswerDocument[];
assistantOptions: Persona[];
setAlternativeAssistant: (alternativeAssistant: Persona | null) => void;
availableAssistants: Persona[];
onSetSelectedAssistant: (alternativeAssistant: Persona | null) => void;
setSelectedAssistant: (assistant: Persona) => void;
personas: Persona[];
message: string;
setMessage: (message: string) => void;
onSubmit: () => void;
isStreaming: boolean;
setIsCancelled: (value: boolean) => void;
retrievalDisabled: boolean;
filterManager: FilterManager;
llmOverrideManager: LlmOverrideManager;
selectedAssistant: Persona;
@@ -68,6 +72,7 @@ export function ChatInputBar({
files: FileDescriptor[];
setFiles: (files: FileDescriptor[]) => void;
handleFileUpload: (files: File[]) => void;
setConfigModalActiveTab: (tab: string) => void;
textAreaRef: React.RefObject<HTMLTextAreaElement>;
chatSessionId?: number;
}) {
@@ -131,10 +136,8 @@ export function ChatInputBar({
};

// Update selected persona
const updatedTaggedAssistant = (assistant: Persona) => {
setAlternativeAssistant(
assistant.id == selectedAssistant.id ? null : assistant
);
const updateCurrentPersona = (persona: Persona) => {
onSetSelectedAssistant(persona.id == selectedAssistant.id ? null : persona);
hideSuggestions();
setMessage("");
};
@@ -157,8 +160,8 @@ export function ChatInputBar({
}
};

const assistantTagOptions = assistantOptions.filter((assistant) =>
assistant.name.toLowerCase().startsWith(
const filteredPersonas = personas.filter((persona) =>
persona.name.toLowerCase().startsWith(
message
.slice(message.lastIndexOf("@") + 1)
.split(/\s/)[0]
@@ -171,18 +174,18 @@ export function ChatInputBar({
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (
showSuggestions &&
assistantTagOptions.length > 0 &&
filteredPersonas.length > 0 &&
(e.key === "Tab" || e.key == "Enter")
) {
e.preventDefault();
if (assistantIconIndex == assistantTagOptions.length) {
if (assistantIconIndex == filteredPersonas.length) {
window.open("/assistants/new", "_blank");
hideSuggestions();
setMessage("");
} else {
const option =
assistantTagOptions[assistantIconIndex >= 0 ? assistantIconIndex : 0];
updatedTaggedAssistant(option);
filteredPersonas[assistantIconIndex >= 0 ? assistantIconIndex : 0];
updateCurrentPersona(option);
}
}
if (!showSuggestions) {
@@ -192,7 +195,7 @@ export function ChatInputBar({
if (e.key === "ArrowDown") {
e.preventDefault();
setAssistantIconIndex((assistantIconIndex) =>
Math.min(assistantIconIndex + 1, assistantTagOptions.length)
Math.min(assistantIconIndex + 1, filteredPersonas.length)
);
} else if (e.key === "ArrowUp") {
e.preventDefault();
@@ -216,36 +219,35 @@ export function ChatInputBar({
mx-auto
"
>
{showSuggestions && assistantTagOptions.length > 0 && (
{showSuggestions && filteredPersonas.length > 0 && (
<div
ref={suggestionsRef}
className="text-sm absolute inset-x-0 top-0 w-full transform -translate-y-full"
>
<div className="rounded-lg py-1.5 bg-background border border-border-medium shadow-lg mx-2 px-1.5 mt-2 rounded z-10">
{assistantTagOptions.map((currentAssistant, index) => (
{filteredPersonas.map((currentPersona, index) => (
<button
key={index}
className={`px-2 ${
assistantIconIndex == index && "bg-hover-lightish"
} rounded rounded-lg content-start flex gap-x-1 py-2 w-full hover:bg-hover-lightish cursor-pointer`}
onClick={() => {
updatedTaggedAssistant(currentAssistant);
updateCurrentPersona(currentPersona);
}}
>
<p className="font-bold">{currentAssistant.name}</p>
<p className="font-bold">{currentPersona.name}</p>
<p className="line-clamp-1">
{currentAssistant.id == selectedAssistant.id &&
{currentPersona.id == selectedAssistant.id &&
"(default) "}
{currentAssistant.description}
{currentPersona.description}
</p>
</button>
))}
<a
key={assistantTagOptions.length}
key={filteredPersonas.length}
target="_blank"
className={`${
assistantIconIndex == assistantTagOptions.length &&
"bg-hover"
assistantIconIndex == filteredPersonas.length && "bg-hover"
} rounded rounded-lg px-3 flex gap-x-1 py-2 w-full items-center hover:bg-hover-lightish cursor-pointer"`}
href="/assistants/new"
>
@@ -299,7 +301,7 @@ export function ChatInputBar({

<Hoverable
icon={FiX}
onClick={() => setAlternativeAssistant(null)}
onClick={() => onSetSelectedAssistant(null)}
/>
</div>
</div>
@@ -407,7 +409,7 @@ export function ChatInputBar({
removePadding
content={(close) => (
<AssistantsTab
availableAssistants={assistantOptions}
availableAssistants={availableAssistants}
llmProviders={llmProviders}
selectedAssistant={selectedAssistant}
onSelect={(assistant) => {
@@ -505,7 +505,7 @@ export function removeMessage(

export function checkAnyAssistantHasSearch(
messageHistory: Message[],
availableAssistants: Persona[],
availablePersonas: Persona[],
livePersona: Persona
): boolean {
const response =
@@ -516,8 +516,8 @@ export function checkAnyAssistantHasSearch(
) {
return false;
}
const alternateAssistant = availableAssistants.find(
(assistant) => assistant.id === message.alternateAssistantID
const alternateAssistant = availablePersonas.find(
(persona) => persona.id === message.alternateAssistantID
);
return alternateAssistant
? personaIncludesRetrieval(alternateAssistant)
@@ -798,7 +798,7 @@ export const HumanMessage = ({
!isEditing &&
(!files || files.length === 0)
) && "ml-auto"
} relative max-w-[70%] mb-auto whitespace-break-spaces rounded-3xl bg-user px-5 py-2.5`}
} relative max-w-[70%] mb-auto rounded-3xl bg-user px-5 py-2.5`}
>
{content}
</div>
@@ -62,10 +62,11 @@ export function AssistantsTab({
toolName = "Image Generation";
toolIcon = <FiImage className="mr-1 my-auto" />;
}

return (
<Bubble key={tool.id} isSelected={false}>
<div className="flex line-wrap break-all flex-row gap-1">
<div className="flex-none my-auto">{toolIcon}</div>
<div className="flex flex-row gap-1">
{toolIcon}
{toolName}
</div>
</Bubble>
web/src/app/chat/modal/configuration/ConfigurationModal.tsx (new file, 180 lines)
@@ -0,0 +1,180 @@
|
||||
"use client";
|
||||
|
||||
import React, { useEffect } from "react";
|
||||
import { Modal } from "../../../../components/Modal";
|
||||
import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
|
||||
import { FiltersTab } from "./FiltersTab";
|
||||
import { FiCpu, FiFilter, FiX } from "react-icons/fi";
|
||||
import { IconType } from "react-icons";
|
||||
import { FaBrain } from "react-icons/fa";
|
||||
import { AssistantsTab } from "./AssistantsTab";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { LlmTab } from "./LlmTab";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
|
||||
import { AssistantsIcon, IconProps } from "@/components/icons/icons";
|
||||
|
||||
const TabButton = ({
|
||||
label,
|
||||
icon: Icon,
|
||||
isActive,
|
||||
onClick,
|
||||
}: {
|
||||
label: string;
|
||||
icon: IconType;
|
||||
isActive: boolean;
|
||||
onClick: () => void;
|
||||
}) => (
|
||||
<button
|
||||
onClick={onClick}
|
||||
className={`
|
||||
pb-4
|
||||
pt-6
|
||||
px-2
|
||||
text-emphasis
|
||||
font-bold
|
||||
${isActive ? "border-b-2 border-accent" : ""}
|
||||
hover:bg-hover-light
|
||||
hover:text-strong
|
||||
transition
|
||||
duration-200
|
||||
ease-in-out
|
||||
flex
|
||||
`}
|
||||
>
|
||||
<Icon className="inline-block mr-2 my-auto" size="16" />
|
||||
|
||||
<p className="my-auto">{label}</p>
|
||||
</button>
|
||||
);
|
||||
|
||||
export function ConfigurationModal({
|
||||
activeTab,
|
||||
setActiveTab,
|
||||
onClose,
|
||||
availableAssistants,
|
||||
selectedAssistant,
|
||||
setSelectedAssistant,
|
||||
filterManager,
|
||||
llmProviders,
|
||||
llmOverrideManager,
|
||||
chatSessionId,
|
||||
}: {
|
||||
activeTab: string | null;
|
||||
setActiveTab: (tab: string | null) => void;
|
||||
onClose: () => void;
|
||||
availableAssistants: Persona[];
|
||||
selectedAssistant: Persona;
|
||||
setSelectedAssistant: (assistant: Persona) => void;
|
||||
filterManager: FilterManager;
|
||||
llmProviders: LLMProviderDescriptor[];
|
||||
llmOverrideManager: LlmOverrideManager;
|
||||
chatSessionId?: number;
|
||||
}) {
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (event.key === "Escape") {
|
||||
onClose();
|
||||
}
|
||||
};
|
||||
|
||||
document.addEventListener("keydown", handleKeyDown);
|
||||
return () => {
|
||||
document.removeEventListener("keydown", handleKeyDown);
|
||||
};
|
||||
}, [onClose]);
|
||||
|
||||
if (!activeTab) return null;
|
||||
|
||||
return (
|
||||
<Modal
|
||||
onOutsideClick={onClose}
|
||||
noPadding
|
||||
className="
|
||||
w-4/6
|
||||
h-4/6
|
||||
flex
|
||||
flex-col
|
||||
"
|
||||
>
|
||||
<div className="rounded flex flex-col overflow-hidden">
|
||||
<div className="mb-4">
|
||||
<div className="flex border-b border-border bg-background-emphasis">
|
||||
<div className="flex px-6 gap-x-2">
|
||||
<TabButton
|
||||
label="Assistants"
|
||||
icon={FaBrain}
|
||||
isActive={activeTab === "assistants"}
|
||||
onClick={() => setActiveTab("assistants")}
|
||||
/>
|
||||
<TabButton
|
||||
label="Models"
|
||||
icon={FiCpu}
|
||||
isActive={activeTab === "llms"}
|
||||
onClick={() => setActiveTab("llms")}
|
||||
/>
|
||||
<TabButton
|
||||
label="Filters"
|
||||
icon={FiFilter}
|
||||
isActive={activeTab === "filters"}
|
||||
onClick={() => setActiveTab("filters")}
|
||||
/>
|
||||
</div>
|
||||
<button
|
||||
className="
|
||||
ml-auto
|
||||
px-1
|
||||
py-1
|
||||
text-xs
|
||||
font-medium
|
||||
rounded
|
||||
hover:bg-hover
|
||||
focus:outline-none
|
||||
focus:ring-2
|
||||
focus:ring-offset-2
|
||||
focus:ring-subtle
|
||||
flex
|
||||
items-center
|
||||
h-fit
|
||||
my-auto
|
||||
mr-5
|
||||
"
|
||||
onClick={onClose}
|
||||
>
|
||||
<FiX size={24} />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col overflow-y-auto">
|
||||
<div className="px-8 pt-4">
|
||||
{activeTab === "filters" && (
|
||||
<FiltersTab filterManager={filterManager} />
|
||||
)}
|
||||
|
||||
{activeTab === "llms" && (
|
||||
<LlmTab
|
||||
chatSessionId={chatSessionId}
|
||||
llmOverrideManager={llmOverrideManager}
|
||||
currentAssistant={selectedAssistant}
|
||||
/>
|
||||
)}
|
||||
|
||||
{activeTab === "assistants" && (
|
||||
<div>
|
||||
<AssistantsTab
|
||||
availableAssistants={availableAssistants}
|
||||
llmProviders={llmProviders}
|
||||
selectedAssistant={selectedAssistant}
|
||||
onSelect={(assistant) => {
|
||||
setSelectedAssistant(assistant);
|
||||
onClose();
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
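For orientation, below is a minimal sketch of how a caller might wire up the new ConfigurationModal. Only the prop names come from the component above; the surrounding state and variables (configModalActiveTab, and the filterManager, llmOverrideManager, availableAssistants, llmProviders, chatSessionId instances) are assumptions for illustration. The modal renders nothing until activeTab is set and closes itself on Escape or an outside click.

// Hypothetical parent wiring (illustration only); filterManager, llmOverrideManager,
// availableAssistants, llmProviders and chatSessionId are assumed to already exist
// in the calling chat page component.
const [configModalActiveTab, setConfigModalActiveTab] = useState<string | null>(null);
const [selectedAssistant, setSelectedAssistant] = useState<Persona>(availableAssistants[0]);

<ConfigurationModal
  activeTab={configModalActiveTab}
  setActiveTab={setConfigModalActiveTab}
  onClose={() => setConfigModalActiveTab(null)}
  availableAssistants={availableAssistants}
  selectedAssistant={selectedAssistant}
  setSelectedAssistant={setSelectedAssistant}
  filterManager={filterManager}
  llmProviders={llmProviders}
  llmOverrideManager={llmOverrideManager}
  chatSessionId={chatSessionId}
/>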
@@ -35,7 +35,7 @@ export default async function Page({
folders,
toggleSidebar,
openedFolders,
defaultAssistantId,
defaultPersonaId,
finalDocumentSidebarInitialWidth,
shouldShowWelcomeModal,
shouldDisplaySourcesIncompleteModal,
@@ -58,7 +58,7 @@ export default async function Page({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availableAssistants: assistants,
availablePersonas: assistants,
availableTags: tags,
llmProviders,
folders,
@@ -66,7 +66,7 @@ export default async function Page({
}}
>
<WrappedChat
defaultAssistantId={defaultAssistantId}
defaultPersonaId={defaultPersonaId}
initiallyToggled={toggleSidebar}
/>
</ChatProvider>

111
web/src/app/chat/sessionSidebar/AssistantsTab.tsx
Normal file
@@ -0,0 +1,111 @@
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { BasicSelectable } from "@/components/BasicClickable";
|
||||
import { AssistantsIcon } from "@/components/icons/icons";
|
||||
import { User } from "@/lib/types";
|
||||
import { Text } from "@tremor/react";
|
||||
import Link from "next/link";
|
||||
import { FaRobot } from "react-icons/fa";
|
||||
import { FiEdit2 } from "react-icons/fi";
|
||||
|
||||
function AssistantDisplay({
|
||||
persona,
|
||||
onSelect,
|
||||
user,
|
||||
}: {
|
||||
persona: Persona;
|
||||
onSelect: (persona: Persona) => void;
|
||||
user: User | null;
|
||||
}) {
|
||||
const isEditable =
|
||||
(!user || user.id === persona.owner?.id) &&
|
||||
!persona.default_persona &&
|
||||
(!persona.is_public || !user || user.role === "admin");
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
<div className="w-full" onClick={() => onSelect(persona)}>
|
||||
<BasicSelectable selected={false} fullWidth>
|
||||
<div className="flex">
|
||||
<div className="truncate w-48 3xl:w-56 flex">
|
||||
<AssistantsIcon className="mr-2 my-auto" size={16} />{" "}
|
||||
{persona.name}
|
||||
</div>
|
||||
</div>
|
||||
</BasicSelectable>
|
||||
</div>
|
||||
{isEditable && (
|
||||
<div className="pl-2 my-auto">
|
||||
<Link href={`/assistants/edit/${persona.id}`}>
|
||||
<FiEdit2
|
||||
className="my-auto ml-auto hover:bg-hover p-0.5"
|
||||
size={20}
|
||||
/>
|
||||
</Link>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function AssistantsTab({
|
||||
personas,
|
||||
onPersonaChange,
|
||||
user,
|
||||
}: {
|
||||
personas: Persona[];
|
||||
onPersonaChange: (persona: Persona | null) => void;
|
||||
user: User | null;
|
||||
}) {
|
||||
const globalAssistants = personas.filter((persona) => persona.is_public);
|
||||
const personalAssistants = personas.filter(
|
||||
(persona) =>
|
||||
(!user || persona.users.some((u) => u.id === user.id)) &&
|
||||
!persona.is_public
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="mt-4 pb-1 overflow-y-auto h-full flex flex-col gap-y-1">
|
||||
<Text className="mx-3 text-xs mb-4">
|
||||
Select an Assistant below to begin a new chat with them!
|
||||
</Text>
|
||||
|
||||
<div className="mx-3">
|
||||
{globalAssistants.length > 0 && (
|
||||
<>
|
||||
<div className="text-xs text-subtle flex pb-0.5 ml-1 mb-1.5 font-bold">
|
||||
Global
|
||||
</div>
|
||||
{globalAssistants.map((persona) => {
|
||||
return (
|
||||
<AssistantDisplay
|
||||
key={persona.id}
|
||||
persona={persona}
|
||||
onSelect={onPersonaChange}
|
||||
user={user}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
|
||||
{personalAssistants.length > 0 && (
|
||||
<>
|
||||
<div className="text-xs text-subtle flex pb-0.5 ml-1 mb-1.5 mt-5 font-bold">
|
||||
Personal
|
||||
</div>
|
||||
{personalAssistants.map((persona) => {
|
||||
return (
|
||||
<AssistantDisplay
|
||||
key={persona.id}
|
||||
persona={persona}
|
||||
onSelect={onPersonaChange}
|
||||
user={user}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
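For orientation, a minimal sketch of how the sidebar might render this tab. The prop names match the component above; the handler body and the availablePersonas, user, and router names are assumptions for illustration, though the assistantId query parameter does appear in the fetchChatData changes below.

// Hypothetical usage (illustration only); availablePersonas, user, and router
// are assumed to come from the surrounding chat sidebar component.
<AssistantsTab
  personas={availablePersonas}
  user={user}
  onPersonaChange={(persona) => {
    if (persona) {
      // e.g. start a fresh chat with the chosen persona
      router.push(`/chat?assistantId=${persona.id}`);
    }
  }}
/>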
@@ -11,7 +11,7 @@ export default function FixedLogo() {
const enterpriseSettings = combinedSettings?.enterpriseSettings;

return (
<div className="absolute flex z-40 left-4 top-2">
<div className="fixed flex z-40 left-4 top-2">
{" "}
<a href="/chat" className="ml-7 text-text-700 text-xl">
<div>

@@ -18,7 +18,7 @@ interface ChatContextProps {
chatSessions: ChatSession[];
availableSources: ValidSources[];
availableDocumentSets: DocumentSet[];
availableAssistants: Persona[];
availablePersonas: Persona[];
availableTags: Tag[];
llmProviders: LLMProviderDescriptor[];
folders: Folder[];

@@ -19,7 +19,7 @@ export function Citation({
return (
<CustomTooltip
citation
content={<div className="inline-block p-0 m-0 truncate">{link}</div>}
content={<p className="inline-block p-0 m-0 truncate">{link}</p>}
>
<a
onClick={() => (link ? window.open(link, "_blank") : undefined)}

@@ -30,6 +30,7 @@ import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS";
interface FetchChatDataResult {
user: User | null;
chatSessions: ChatSession[];

ccPairs: CCPairBasicInfo[];
availableSources: ValidSources[];
documentSets: DocumentSet[];
@@ -38,7 +39,7 @@ interface FetchChatDataResult {
llmProviders: LLMProviderDescriptor[];
folders: Folder[];
openedFolders: Record<string, boolean>;
defaultAssistantId?: number;
defaultPersonaId?: number;
toggleSidebar: boolean;
finalDocumentSidebarInitialWidth?: number;
shouldShowWelcomeModal: boolean;
@@ -149,9 +150,9 @@ export async function fetchChatData(searchParams: {
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
}

const defaultAssistantIdRaw = searchParams["assistantId"];
const defaultAssistantId = defaultAssistantIdRaw
? parseInt(defaultAssistantIdRaw)
const defaultPersonaIdRaw = searchParams["assistantId"];
const defaultPersonaId = defaultPersonaIdRaw
? parseInt(defaultPersonaIdRaw)
: undefined;

const documentSidebarCookieInitialWidth = cookies().get(
@@ -208,7 +209,7 @@ export async function fetchChatData(searchParams: {
llmProviders,
folders,
openedFolders,
defaultAssistantId,
defaultPersonaId,
finalDocumentSidebarInitialWidth,
toggleSidebar,
shouldShowWelcomeModal,