Compare commits

...

21 Commits

Author SHA1 Message Date
pablodanswer
4029233df0 hide incomplete sources for non-admins (#1901) 2024-07-22 13:40:11 -07:00
hagen-danswer
6c88c0156c Added file upload retry logic (#1889) 2024-07-22 13:13:22 -07:00
pablodanswer
33332d08f2 fix citation title (#1900)
* fix citation title

* remove title function
2024-07-22 17:37:04 +00:00
hagen-danswer
17005fb705 switched default pruning behavior and removed some logging (#1898) 2024-07-22 17:36:26 +00:00
hagen-danswer
48a7fe80b1 Committed LLM updates to db (#1899) 2024-07-22 10:30:24 -07:00
pablodanswer
1276732409 Misc bug fixes (#1895) 2024-07-22 10:22:43 -07:00
Weves
f91b92a898 Make is_public default true for LLMProvider 2024-07-21 22:22:37 -07:00
Weves
6222f533be Update force delete script to handle user groups 2024-07-21 22:22:37 -07:00
hagen-danswer
1b49d17239 Added ability to control LLM access based on group (#1870)
* Added ability to control LLM access based on group

* completed relationship deletion

* cleaned up function

* added comments

* fixed frontend strings

* mypy fixes

* added case handling for deletion of user groups

* hidden advanced options now

* removed unnecessary code
2024-07-22 04:31:44 +00:00
Yuhong Sun
2f5f19642e Double Check Max Tokens for Indexing (#1893) 2024-07-21 21:12:39 -07:00
Yuhong Sun
6db4634871 Token Truncation (#1892) 2024-07-21 16:26:32 -07:00
Yuhong Sun
5cfed45cef Handle Empty Titles (#1891) 2024-07-21 14:59:23 -07:00
Weves
581ffde35a Fix jira connector failures for server deployments 2024-07-21 14:44:25 -07:00
pablodanswer
6313e6d91d Remove visit api when unneeded (#1885)
* quick fix to test on ec2

* quick cleanup

* modify a name

* address full doc as well

* additional timing info + handling

* clean up

* squash

* Print only
2024-07-21 20:57:24 +00:00
Weves
c09c94bf32 Fix assistant swap 2024-07-21 13:57:36 -07:00
Yuhong Sun
0e8ba111c8 Model Touchups (#1887) 2024-07-21 12:31:00 -07:00
Yuhong Sun
2ba24b1734 Reenable Search Pipeline (#1886) 2024-07-21 10:33:29 -07:00
Yuhong Sun
44820b4909 k 2024-07-21 10:27:57 -07:00
hagen-danswer
eb3e7610fc Added retries and multithreading for cloud embedding (#1879)
* added retries and multithreading for cloud embedding

* refactored a bit

* cleaned up code

* got the errors to bubble up to the ui correctly

* added exception printing

* added requirements

* touchups

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-07-20 22:10:18 -07:00
pablodanswer
7fbbb174bb minor fixes (#1882)
- Assistants tab size
- Fixed logo -> absolute
2024-07-20 21:02:57 -07:00
pablodanswer
3854ca11af add newlines for message content 2024-07-20 18:57:29 -07:00
76 changed files with 1355 additions and 1088 deletions

View File

@@ -0,0 +1,41 @@
"""add_llm_group_permissions_control
Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558
"""
from alembic import op
import sqlalchemy as sa
revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"llm_provider__user_group",
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["llm_provider_id"],
["llm_provider.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
)
op.add_column(
"llm_provider",
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_table("llm_provider__user_group")
op.drop_column("llm_provider", "is_public")
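As context for this migration, here is a minimal, self-contained sketch (toy models against in-memory SQLite, not Danswer's real ones) of what the new schema enables: providers flagged is_public stay visible to everyone, while non-public providers are reachable only through llm_provider__user_group rows.

from sqlalchemy import (
    Boolean, Column, ForeignKey, Integer, String, create_engine, or_, select,
)
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class LLMProvider(Base):
    __tablename__ = "llm_provider"
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    # mirrors the migration: existing rows default to public
    is_public = Column(Boolean, nullable=False, default=True, server_default="true")

class UserGroup(Base):
    __tablename__ = "user_group"
    id = Column(Integer, primary_key=True)

class LLMProvider__UserGroup(Base):
    __tablename__ = "llm_provider__user_group"
    llm_provider_id = Column(Integer, ForeignKey("llm_provider.id"), primary_key=True)
    user_group_id = Column(Integer, ForeignKey("user_group.id"), primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([
        LLMProvider(id=1, name="shared-openai"),                  # public
        LLMProvider(id=2, name="team-azure", is_public=False),    # group-scoped
        UserGroup(id=10),
        LLMProvider__UserGroup(llm_provider_id=2, user_group_id=10),
    ])
    session.commit()
    # Providers visible to a member of group 10: public ones OR group-linked ones.
    visible = session.scalars(
        select(LLMProvider).where(
            or_(
                LLMProvider.is_public,
                LLMProvider.id.in_(
                    select(LLMProvider__UserGroup.llm_provider_id).where(
                        LLMProvider__UserGroup.user_group_id == 10
                    )
                ),
            )
        )
    )
    print([p.name for p in visible])  # ['shared-openai', 'team-azure']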

View File

@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -80,7 +80,7 @@ def should_prune_cc_pair(
return True
return False
if PREVENT_SIMULTANEOUS_PRUNING:
if not ALLOW_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
@@ -89,11 +89,9 @@ def should_prune_cc_pair(
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
logger.info("Another Connector is already pruning. Skipping.")
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
return False
if not last_pruning_task.start_time:
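A small illustration, assuming (consistent with the app_configs diff later in this compare) that the opt-out PREVENT_SIMULTANEOUS_PRUNING flag was replaced by an opt-in ALLOW_SIMULTANEOUS_PRUNING flag, so concurrent pruning of different connectors is now skipped by default:

import os

# Hypothetical stand-in for the new default behavior, not the Danswer code itself.
ALLOW_SIMULTANEOUS_PRUNING = (
    os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)

def should_skip(another_connector_is_pruning: bool) -> bool:
    # With the flag unset, only one pruning task may run at a time.
    return not ALLOW_SIMULTANEOUS_PRUNING and another_connector_is_pruning

print(should_skip(another_connector_is_pruning=True))  # True unless the env var is set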

View File

@@ -33,7 +33,7 @@ from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable

View File

@@ -51,7 +51,7 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
@@ -253,7 +253,6 @@ def stream_chat_message_objects(
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
try:
user_id = user.id if user is not None else None

View File

@@ -214,8 +214,8 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
PREVENT_SIMULTANEOUS_PRUNING = (
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.

View File

@@ -56,6 +56,16 @@ def extract_text_from_content(content: dict) -> str:
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
@@ -117,8 +127,10 @@ def fetch_jira_issues_batch(
continue
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments]
semantic_rep = (
f"{jira.fields.description}\n"
if jira.fields.description
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
)
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -147,14 +159,18 @@ def fetch_jira_issues_batch(
pass
metadata_dict = {}
if jira.fields.priority:
metadata_dict["priority"] = jira.fields.priority.name
if jira.fields.status:
metadata_dict["status"] = jira.fields.status.name
if jira.fields.resolution:
metadata_dict["resolution"] = jira.fields.resolution.name
if jira.fields.labels:
metadata_dict["label"] = jira.fields.labels
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
metadata_dict["priority"] = priority.name
status = best_effort_get_field_from_issue(jira, "status")
if status:
metadata_dict["status"] = status.name
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
metadata_dict["resolution"] = resolution.name
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
metadata_dict["label"] = labels
doc_batch.append(
Document(
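The metadata handling above now goes through a best-effort field lookup. Below is a standalone sketch of that pattern using a stand-in object instead of a real jira Issue (the real helper prefers the typed attribute and falls back to the issue's raw fields dict):

from types import SimpleNamespace
from typing import Any

def best_effort_get_field(issue: Any, field: str) -> Any:
    if hasattr(issue.fields, field):
        return getattr(issue.fields, field)
    try:
        return issue.raw["fields"][field]
    except Exception:
        return None

# Stand-in issue: "priority" exists only in the raw payload, "resolution" nowhere.
issue = SimpleNamespace(
    fields=SimpleNamespace(status=SimpleNamespace(name="Done")),
    raw={"fields": {"priority": SimpleNamespace(name="High")}},
)
print(best_effort_get_field(issue, "status").name)    # Done
print(best_effort_get_field(issue, "priority").name)  # High
print(best_effort_get_field(issue, "resolution"))     # None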

View File

@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
self.permissions: DiscoursePerms | None = None
self.active_categories: set | None = None
@rate_limit_builder(max_calls=100, period=60)
@rate_limit_builder(max_calls=50, period=60)
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
if not self.permissions:
raise ConnectorMissingCredentialError("Discourse")

View File

@@ -123,6 +123,8 @@ class DocumentBase(BaseModel):
for char in replace_chars:
title = title.replace(char, " ")
title = title.strip()
# Title could be quite long here as there is no truncation done
# just prior to embedding, it could be truncated
return title
def get_metadata_str_attributes(self) -> list[str] | None:

View File

@@ -50,9 +50,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST

View File

@@ -15,7 +15,7 @@ from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)

View File

@@ -1,15 +1,41 @@
from sqlalchemy import delete
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def update_group_llm_provider_relationships__no_commit(
llm_provider_id: int,
group_ids: list[int] | None,
db_session: Session,
) -> None:
# Delete existing relationships
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
).delete(synchronize_session="fetch")
# Add new relationships from given group_ids
if group_ids:
new_relationships = [
LLMProvider__UserGroup(
llm_provider_id=llm_provider_id,
user_group_id=group_id,
)
for group_id in group_ids
]
db_session.add_all(new_relationships)
def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
@@ -36,36 +62,35 @@ def upsert_llm_provider(
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider:
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = (
llm_provider.fast_default_model_name
)
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
provider=llm_provider.provider,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
default_model_name=llm_provider.default_model_name,
fast_default_model_name=llm_provider.fast_default_model_name,
model_names=llm_provider.model_names,
is_default_provider=None,
if not existing_llm_provider:
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
existing_llm_provider.model_names = llm_provider.model_names
existing_llm_provider.is_public = llm_provider.is_public
if not existing_llm_provider.id:
# If its not already in the db, we need to generate an ID by flushing
db_session.flush()
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
group_ids=llm_provider.groups,
db_session=db_session,
)
db_session.add(llm_provider_model)
db_session.commit()
return FullLLMProvider.from_model(llm_provider_model)
return FullLLMProvider.from_model(existing_llm_provider)
def fetch_existing_embedding_providers(
@@ -74,8 +99,29 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,
) -> list[LLMProviderModel]:
if not user:
return list(db_session.scalars(select(LLMProviderModel)).all())
stmt = select(LLMProviderModel).distinct()
user_groups_subquery = (
select(User__UserGroup.user_group_id)
.where(User__UserGroup.user_id == user.id)
.subquery()
)
access_conditions = or_(
LLMProviderModel.is_public,
LLMProviderModel.id.in_( # User is part of a group that has access
select(LLMProvider__UserGroup.llm_provider_id).where(
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
)
),
)
stmt = stmt.where(access_conditions)
return list(db_session.scalars(stmt).all())
def fetch_embedding_provider(
@@ -119,6 +165,13 @@ def remove_embedding_provider(
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
# Remove LLMProvider's dependent relationships
db_session.execute(
delete(LLMProvider__UserGroup).where(
LLMProvider__UserGroup.llm_provider_id == provider_id
)
)
# Remove LLMProvider
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)
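The reworked upsert above follows a create-or-update-then-flush pattern so a brand-new provider gets a primary key before its group relationships are written. A minimal sketch of that pattern with toy models against in-memory SQLite (not Danswer's actual models):

from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Provider(Base):
    __tablename__ = "provider"
    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True, nullable=False)

class Provider__Group(Base):
    __tablename__ = "provider__group"
    provider_id = Column(Integer, ForeignKey("provider.id"), primary_key=True)
    group_id = Column(Integer, primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

def upsert_provider(session: Session, name: str, group_ids: list[int]) -> Provider:
    provider = session.query(Provider).filter_by(name=name).one_or_none()
    if provider is None:
        provider = Provider(name=name)
        session.add(provider)
    if provider.id is None:
        session.flush()  # assigns the autoincrement id without committing
    # Rebuild the join table for this provider ("__no_commit" style helpers would stop here).
    session.query(Provider__Group).filter_by(provider_id=provider.id).delete()
    session.add_all(
        Provider__Group(provider_id=provider.id, group_id=g) for g in group_ids
    )
    session.commit()
    return provider

with Session(engine) as session:
    provider = upsert_provider(session, "team-azure", [1, 2])
    print(provider.id, session.query(Provider__Group).count())  # 1 2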

View File

@@ -932,6 +932,13 @@ class LLMProvider(Base):
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="llm_provider__user_group",
viewonly=True,
)
class CloudEmbeddingProvider(Base):
@@ -1109,7 +1116,6 @@ class Persona(Base):
# where lower value IDs (e.g. created earlier) are displayed first
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
# These are only defaults, users can select from all if desired
prompts: Mapped[list[Prompt]] = relationship(
@@ -1137,6 +1143,7 @@ class Persona(Base):
viewonly=True,
)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="persona__user_group",
@@ -1360,6 +1367,17 @@ class Persona__UserGroup(Base):
)
class LLMProvider__UserGroup(Base):
__tablename__ = "llm_provider__user_group"
llm_provider_id: Mapped[int] = mapped_column(
ForeignKey("llm_provider.id"), primary_key=True
)
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
class DocumentSet__UserGroup(Base):
__tablename__ = "document_set__user_group"

View File

@@ -103,6 +103,7 @@ def port_api_key_to_postgres() -> None:
default_model_name=default_model_name,
fast_default_model_name=default_fast_model_name,
model_names=None,
is_public=True,
)
llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
update_default_provider(db_session, llm_provider.id)

View File

@@ -15,7 +15,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
)
from danswer.connectors.models import Document
from danswer.indexing.models import DocAwareChunk
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import shared_precompare_cleanup

View File

@@ -4,7 +4,6 @@ from abc import abstractmethod
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -14,8 +13,7 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.utils.batching import batch_list
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
@@ -66,12 +64,12 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
)
def embed_chunks(
self,
chunks: list[DocAwareChunk],
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[IndexChunk]:
# Cache the Title embeddings to only have to do it once
@@ -80,7 +78,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
# Create Mini Chunks for more precise matching of details
# Off by default with unedited settings
chunk_texts = []
chunk_texts: list[str] = []
chunk_mini_chunks_count = {}
for chunk_ind, chunk in enumerate(chunks):
chunk_texts.append(chunk.content)
@@ -92,22 +90,9 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
chunk_texts.extend(mini_chunk_texts)
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
# Batching for embedding
text_batches = batch_list(chunk_texts, batch_size)
embeddings: list[list[float]] = []
len_text_batches = len(text_batches)
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
# Normalize embeddings is only configured via model_configs.py, be sure to use right
# value for the set loss
embeddings.extend(
self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
)
# Replace line above with the line below for easy debugging of indexing flow
# skipping the actual model
# embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
embeddings = self.embedding_model.encode(
chunk_texts, text_type=EmbedTextType.PASSAGE
)
chunk_titles = {
chunk.source_document.get_title_for_document_index() for chunk in chunks
@@ -116,16 +101,15 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
# Drop any None or empty strings
chunk_titles_list = [title for title in chunk_titles if title]
# Embed Titles in batches
title_batches = batch_list(chunk_titles_list, batch_size)
len_title_batches = len(title_batches)
for ind_batch, title_batch in enumerate(title_batches, start=1):
logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
if chunk_titles_list:
title_embeddings = self.embedding_model.encode(
title_batch, text_type=EmbedTextType.PASSAGE
chunk_titles_list, text_type=EmbedTextType.PASSAGE
)
title_embed_dict.update(
{title: vector for title, vector in zip(title_batch, title_embeddings)}
{
title: vector
for title, vector in zip(chunk_titles_list, title_embeddings)
}
)
# Mapping embeddings to chunks
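Title embeddings above are now deduplicated and sent through a single encode call. A toy, self-contained version of that dedup follows (fake_encode stands in for EmbeddingModel.encode, which now handles batching internally):

def fake_encode(texts: list[str]) -> list[list[float]]:
    # placeholder for self.embedding_model.encode(texts, text_type=EmbedTextType.PASSAGE)
    return [[float(len(t))] for t in texts]

chunk_titles = ["Doc A", "Doc A", "", "Doc B"]        # one entry per chunk
unique_titles = [t for t in set(chunk_titles) if t]   # drop empty titles, dedupe
title_embed_dict = dict(zip(unique_titles, fake_encode(unique_titles)))

for title in chunk_titles:
    vector = title_embed_dict.get(title)  # None for chunks whose document has no usable title
    print(title or "<no title>", vector)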

View File

@@ -34,8 +34,8 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import message_generator_to_string_generator
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
)

View File

@@ -12,8 +12,8 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow

View File

@@ -14,8 +14,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection

View File

@@ -70,6 +70,7 @@ def load_llm_providers(db_session: Session) -> None:
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
),
model_names=model_names,
is_public=True,
)
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
update_default_provider(db_session, llm_provider.id)

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable
from collections.abc import Iterator
from copy import copy
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -16,10 +15,8 @@ from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from tiktoken.core import Encoding
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
@@ -28,7 +25,6 @@ from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL
@@ -37,53 +33,6 @@ if TYPE_CHECKING:
logger = setup_logger()
_LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
def get_default_llm_tokenizer() -> Encoding:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
if _LLM_TOKENIZER is None:
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
return _LLM_TOKENIZER
def get_default_llm_token_encode() -> Callable[[str], Any]:
global _LLM_TOKENIZER_ENCODE
if _LLM_TOKENIZER_ENCODE is None:
tokenizer = get_default_llm_tokenizer()
if isinstance(tokenizer, Encoding):
return tokenizer.encode # type: ignore
# Currently only supports OpenAI encoder
raise ValueError("Invalid Encoder selected")
return _LLM_TOKENIZER_ENCODE
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: Encoding
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
def tokenizer_trim_chunks(
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
) -> list[InferenceChunk]:
tokenizer = get_default_llm_tokenizer()
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks
def translate_danswer_msg_to_langchain(
msg: Union[ChatMessage, "PreviousMessage"],

View File

@@ -50,8 +50,8 @@ from danswer.db.standard_answer import create_initial_default_standard_answer_ca
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.llm.llm_initialization import load_llm_providers
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router

View File

@@ -1,14 +1,13 @@
import gc
import os
import time
from typing import Optional
from typing import TYPE_CHECKING
import requests
from transformers import logging as transformer_logging # type:ignore
from httpx import HTTPError
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.natural_language_processing.utils import get_default_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
@@ -20,50 +19,13 @@ from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
logger = setup_logger()
if TYPE_CHECKING:
from transformers import AutoTokenizer # type: ignore
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
# NOTE: doing a local import here to avoid reduce memory usage caused by
# processes importing this file despite not using any of this
from transformers import AutoTokenizer # type: ignore
global _TOKENIZER
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
if _TOKENIZER[0] is not None:
del _TOKENIZER
gc.collect()
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
return _TOKENIZER[0]
def build_model_server_url(
model_server_host: str,
model_server_port: int,
@@ -91,6 +53,7 @@ class EmbeddingModel:
provider_type: str | None,
# The following are globals are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
retrim_content: bool = False,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -99,32 +62,91 @@ class EmbeddingModel:
self.passage_prefix = passage_prefix
self.normalize = normalize
self.model_name = model_name
self.retrim_content = retrim_content
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
if text_type == EmbedTextType.QUERY and self.query_prefix:
prefixed_texts = [self.query_prefix + text for text in texts]
elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
prefixed_texts = [self.passage_prefix + text for text in texts]
else:
prefixed_texts = texts
def encode(
self,
texts: list[str],
text_type: EmbedTextType,
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
) -> list[list[float]]:
if not texts:
logger.warning("No texts to be embedded")
return []
embed_request = EmbedRequest(
model_name=self.model_name,
texts=prefixed_texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
)
if self.retrim_content:
# This is applied during indexing as a catchall for overly long titles (or other uncapped fields)
# Note that this uses just the default tokenizer which may also lead to very minor miscountings
# However this slight miscounting is very unlikely to have any material impact.
texts = [
tokenizer_trim_content(
content=text,
desired_length=self.max_seq_length,
tokenizer=get_default_tokenizer(),
)
for text in texts
]
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
response.raise_for_status()
if self.provider_type:
embed_request = EmbedRequest(
model_name=self.model_name,
texts=texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
return EmbedResponse(**response.json()).embeddings
return EmbedResponse(**response.json()).embeddings
# Batching for local embedding
text_batches = batch_list(texts, batch_size)
embeddings: list[list[float]] = []
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Embedding Content Texts batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
# Normalize embeddings is only configured via model_configs.py, be sure to use right
# value for the set loss
embeddings.extend(EmbedResponse(**response.json()).embeddings)
return embeddings
class CrossEncoderEnsembleModel:
@@ -199,7 +221,7 @@ def warm_up_encoders(
# First time downloading the models it may take even longer, but just in case,
# retry the whole server
wait_time = 5
for attempt in range(20):
for _ in range(20):
try:
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
return
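For locally hosted models, batching now lives inside EmbeddingModel.encode rather than in each caller. A self-contained sketch of that control flow (fake_post stands in for the POST to the model server's bi-encoder endpoint):

BATCH_SIZE = 8  # stand-in for BATCH_SIZE_ENCODE_CHUNKS

def batch_list(items: list[str], batch_size: int) -> list[list[str]]:
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]

def fake_post(texts: list[str]) -> list[list[float]]:
    # placeholder for requests.post(embed_server_endpoint, json=EmbedRequest(...).dict())
    return [[0.0] * 4 for _ in texts]

def encode(texts: list[str]) -> list[list[float]]:
    if not texts:
        return []  # mirrors the early return ("No texts to be embedded") in the diff
    embeddings: list[list[float]] = []
    for text_batch in batch_list(texts, BATCH_SIZE):
        embeddings.extend(fake_post(text_batch))
    return embeddings

print(len(encode([f"chunk {i}" for i in range(20)])))  # 20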

View File

@@ -0,0 +1,100 @@
import gc
import os
from collections.abc import Callable
from copy import copy
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
import tiktoken
from tiktoken.core import Encoding
from transformers import logging as transformer_logging # type:ignore
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
if TYPE_CHECKING:
from transformers import AutoTokenizer # type: ignore
logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
_LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
# NOTE: doing a local import here to avoid reduce memory usage caused by
# processes importing this file despite not using any of this
from transformers import AutoTokenizer # type: ignore
global _TOKENIZER
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
if _TOKENIZER[0] is not None:
del _TOKENIZER
gc.collect()
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
return _TOKENIZER[0]
def get_default_llm_tokenizer() -> Encoding:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
if _LLM_TOKENIZER is None:
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
return _LLM_TOKENIZER
def get_default_llm_token_encode() -> Callable[[str], Any]:
global _LLM_TOKENIZER_ENCODE
if _LLM_TOKENIZER_ENCODE is None:
tokenizer = get_default_llm_tokenizer()
if isinstance(tokenizer, Encoding):
return tokenizer.encode # type: ignore
# Currently only supports OpenAI encoder
raise ValueError("Invalid Encoder selected")
return _LLM_TOKENIZER_ENCODE
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: Encoding
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
def tokenizer_trim_chunks(
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
) -> list[InferenceChunk]:
tokenizer = get_default_llm_tokenizer()
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks
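A quick standalone check of the trimming helper moved into this new module (requires the tiktoken package; cl100k_base matches get_default_llm_tokenizer above):

import tiktoken

tokenizer = tiktoken.get_encoding("cl100k_base")

def tokenizer_trim_content(content: str, desired_length: int, tokenizer) -> str:
    tokens = tokenizer.encode(content)
    if len(tokens) > desired_length:
        content = tokenizer.decode(tokens[:desired_length])
    return content

text = "word " * 50
print(len(tokenizer.encode(text)))                                          # roughly 50 tokens
print(len(tokenizer.encode(tokenizer_trim_content(text, 10, tokenizer))))   # at most 10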

View File

@@ -34,7 +34,7 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase

View File

@@ -2,7 +2,7 @@ from collections.abc import Callable
from collections.abc import Generator
from danswer.configs.constants import MessageType
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import ThreadMessage
from danswer.utils.logger import setup_logger

View File

@@ -34,6 +34,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time
logger = setup_logger()
@@ -154,6 +155,7 @@ class SearchPipeline:
return cast(list[InferenceChunk], self._retrieved_chunks)
@log_function_time(print_only=True)
def _get_sections(self) -> list[InferenceSection]:
"""Returns an expanded section from each of the chunks.
If whole docs (instead of above/below context) is specified then it will give back all of the whole docs
@@ -173,9 +175,11 @@ class SearchPipeline:
expanded_inference_sections = []
# Full doc setting takes priority
if self.search_query.full_doc:
seen_document_ids = set()
unique_chunks = []
# This preserves the ordering since the chunks are retrieved in score order
for chunk in retrieved_chunks:
if chunk.document_id not in seen_document_ids:
@@ -195,7 +199,6 @@ class SearchPipeline:
),
)
)
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
@@ -240,32 +243,35 @@ class SearchPipeline:
merged_ranges = [
merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
]
flat_ranges = [r for ranges in merged_ranges for r in ranges]
flat_ranges: list[ChunkRange] = [r for ranges in merged_ranges for r in ranges]
flattened_inference_chunks: list[InferenceChunk] = []
parallel_functions_with_args = []
for chunk_range in flat_ranges:
functions_with_args.append(
(
# If Large Chunks are introduced, additional filters need to be added here
self.document_index.id_based_retrieval,
(
# Only need the document_id here, just use any chunk in the range is fine
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
# There is no chunk level permissioning, this expansion around chunks
# can be assumed to be safe
IndexFilters(access_control_list=None),
),
)
)
# Don't need to fetch chunks within range for merging if chunk_above / below are 0.
if above == below == 0:
flattened_inference_chunks.extend(chunk_range.chunks)
# list of list of inference chunks where the inner list needs to be combined for content
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
flattened_inference_chunks = [
chunk for sublist in list_inference_chunks for chunk in sublist
]
else:
parallel_functions_with_args.append(
(
self.document_index.id_based_retrieval,
(
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
IndexFilters(access_control_list=None),
),
)
)
if parallel_functions_with_args:
list_inference_chunks = run_functions_tuples_in_parallel(
parallel_functions_with_args, allow_failures=False
)
for inference_chunks in list_inference_chunks:
flattened_inference_chunks.extend(inference_chunks)
doc_chunk_ind_to_chunk = {
(chunk.document_id, chunk.chunk_id): chunk
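The section-expansion change above avoids index round-trips when no surrounding context is requested. A toy illustration of that decision (the names here are invented for the sketch; the real code schedules id_based_retrieval calls for the else branch and runs them in parallel):

def expand_ranges(chunk_ranges: list[list[str]], above: int, below: int) -> list[str]:
    flattened: list[str] = []
    pending_fetches: list[list[str]] = []  # stand-in for parallel_functions_with_args
    for chunks_in_range in chunk_ranges:
        if above == below == 0:
            flattened.extend(chunks_in_range)        # nothing to expand, reuse retrieved chunks
        else:
            pending_fetches.append(chunks_in_range)  # would become an index retrieval call
    for fetched in pending_fetches:  # the real pipeline runs these in parallel
        flattened.extend(fetched)
    return flattened

print(expand_ranges([["doc1-c3"], ["doc2-c7", "doc2-c8"]], above=0, below=0))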

View File

@@ -12,6 +12,9 @@ from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import (
CrossEncoderEnsembleModel,
)
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
@@ -20,7 +23,6 @@ from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall

View File

@@ -1,10 +1,10 @@
from typing import TYPE_CHECKING
from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
from danswer.natural_language_processing.search_nlp_models import IntentModel
from danswer.search.enums import QueryFlow
from danswer.search.models import SearchType
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.search.search_nlp_models import IntentModel
from danswer.server.query_and_chat.models import HelperResponse
from danswer.utils.logger import setup_logger

View File

@@ -11,6 +11,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
@@ -20,7 +21,6 @@ from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger

View File

@@ -9,7 +9,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.server.documents.models import ChunkInfo

View File

@@ -10,7 +10,7 @@ from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import remove_embedding_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.models import User
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.embedding.models import TestEmbeddingRequest
@@ -42,7 +42,7 @@ def test_embedding_configuration(
passage_prefix=None,
model_name=None,
)
test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)
except ValueError as e:
error_msg = f"Not a valid embedding model. Exception thrown: {e}"

View File

@@ -147,10 +147,10 @@ def set_provider_as_default(
@basic_router.get("/provider")
def list_llm_provider_basics(
_: User | None = Depends(current_user),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
return [
LLMProviderDescriptor.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
]

View File

@@ -60,6 +60,8 @@ class LLMProvider(BaseModel):
custom_config: dict[str, str] | None
default_model_name: str
fast_default_model_name: str | None
is_public: bool = True
groups: list[int] | None = None
class LLMProviderUpsertRequest(LLMProvider):
@@ -91,4 +93,6 @@ class FullLLMProvider(LLMProvider):
or fetch_models_for_provider(llm_provider_model.provider)
or [llm_provider_model.default_model_name]
),
is_public=llm_provider_model.is_public,
groups=[group.id for group in llm_provider_model.groups],
)

View File

@@ -45,7 +45,7 @@ from danswer.llm.answering.prompts.citations_prompt import (
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.headers import get_litellm_additional_request_headers
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
)

View File

@@ -6,7 +6,7 @@ from langchain_core.messages.tool import ToolCall
from langchain_core.messages.tool import ToolMessage
from pydantic import BaseModel
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
def build_tool_message(

View File

@@ -2,7 +2,7 @@ import json
from tiktoken import Encoding
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.tools.tool import Tool

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
@@ -194,6 +195,15 @@ def _cleanup_user__user_group_relationships__no_commit(
db_session.delete(user__user_group_relationship)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session: Session, user_group_id: int
) -> None:
@@ -316,6 +326,9 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)

View File

@@ -10,6 +10,7 @@ from cohere import Client as CohereClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from vertexai.language_models import TextEmbeddingInput # type: ignore
@@ -40,110 +41,131 @@ router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
# If we are not only indexing, dont want retry very long
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
def _initialize_client(
api_key: str, provider: EmbeddingProvider, model: str | None = None
) -> Any:
if provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=api_key)
elif provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=api_key)
elif provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=api_key)
elif provider == EmbeddingProvider.GOOGLE:
credentials = service_account.Credentials.from_service_account_info(
json.loads(api_key)
)
project_id = json.loads(api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
else:
raise ValueError(f"Unsupported provider: {provider}")
class CloudEmbedding:
def __init__(self, api_key: str, provider: str, model: str | None = None):
self.api_key = api_key
def __init__(
self,
api_key: str,
provider: str,
# Only for Google as is needed on client setup
self.model = model
model: str | None = None,
) -> None:
try:
self.provider = EmbeddingProvider(provider.lower())
except ValueError:
raise ValueError(f"Unsupported provider: {provider}")
self.client = self._initialize_client()
self.client = _initialize_client(api_key, self.provider, model)
def _initialize_client(self) -> Any:
if self.provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=self.api_key)
elif self.provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=self.api_key)
elif self.provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=self.api_key)
elif self.provider == EmbeddingProvider.GOOGLE:
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.api_key)
)
project_id = json.loads(self.api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(
self.model or DEFAULT_VERTEX_MODEL
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def encode(
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
) -> list[list[float]]:
return [
self.embed(text=text, text_type=text_type, model=model_name)
for text in texts
]
def embed(
self, *, text: str, text_type: EmbedTextType, model: str | None = None
) -> list[float]:
logger.debug(f"Embedding text with provider: {self.provider}")
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(text, model)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(text, model, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(text, model, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(text, model, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def _embed_openai(self, text: str, model: str | None) -> list[float]:
def _embed_openai(self, texts: list[str], model: str | None) -> list[list[float]]:
if model is None:
model = DEFAULT_OPENAI_MODEL
response = self.client.embeddings.create(input=text, model=model)
return response.data[0].embedding
# OpenAI does not seem to provide truncation option, however
# the context lengths used by Danswer currently are smaller than the max token length
# for OpenAI embeddings so it's not a big deal
response = self.client.embeddings.create(input=texts, model=model)
return [embedding.embedding for embedding in response.data]
def _embed_cohere(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
if model is None:
model = DEFAULT_COHERE_MODEL
# Does not use the same tokenizer as the Danswer API server but it's approximately the same
# empirically it's only off by a very few tokens so it's not a big deal
response = self.client.embed(
texts=[text],
texts=texts,
model=model,
input_type=embedding_type,
truncate="END",
)
return response.embeddings[0]
return response.embeddings
def _embed_voyage(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
if model is None:
model = DEFAULT_VOYAGE_MODEL
response = self.client.embed(text, model=model, input_type=embedding_type)
return response.embeddings[0]
# Similar to Cohere, the API server will do approximate size chunking
# it's acceptable to miss by a few tokens
response = self.client.embed(
texts,
model=model,
input_type=embedding_type,
truncation=True, # Also this is default
)
return response.embeddings
def _embed_vertex(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float]]:
if model is None:
model = DEFAULT_VERTEX_MODEL
embedding = self.client.get_embeddings(
embeddings = self.client.get_embeddings(
[
TextEmbeddingInput(
text,
embedding_type,
)
]
for text in texts
],
auto_truncate=True, # Also this is default
)
return embedding[0].values
return [embedding.values for embedding in embeddings]
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
def embed(
self,
*,
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
) -> list[list[float]]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error embedding text with {self.provider}: {str(e)}",
)
@staticmethod
def create(
@@ -212,29 +234,52 @@ def embed_text(
normalize_embeddings: bool,
api_key: str | None,
provider_type: str | None,
prefix: str | None,
) -> list[list[float]]:
# Third party API based embedding model
if provider_type is not None:
logger.debug(f"Embedding text with provider: {provider_type}")
if api_key is None:
raise RuntimeError("API key not provided for cloud model")
if prefix:
# This may change in the future if some providers require the user
# to manually append a prefix but this is not the case currently
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
)
cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.encode(texts, model_name, text_type)
embeddings = cloud_model.embed(
texts=texts,
model_name=model_name,
text_type=text_type,
)
# Locally running model
elif model_name is not None:
hosted_model = get_embedding_model(
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
local_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = hosted_model.encode(
texts, normalize_embeddings=normalize_embeddings
embeddings = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
)
else:
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
if embeddings is None:
raise RuntimeError("Embeddings were not created")
raise RuntimeError("Failed to create Embeddings")
if not isinstance(embeddings, list):
embeddings = embeddings.tolist()
return embeddings
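For reference, a minimal usage sketch of embed_text after this change (not part of the diff itself; the import path is hypothetical and the model names and values are only examples). Locally hosted models receive an explicit prefix string, while cloud providers must be given the text type and no prefix:
from shared_configs.enums import EmbedTextType
from model_server.encoders import embed_text  # hypothetical import path
# Locally hosted model: the caller passes the prefix and embed_text prepends it to every text.
local_vectors = embed_text(
    texts=["how do refunds work?"],
    model_name="intfloat/e5-base-v2",  # example model name
    max_context_length=512,
    normalize_embeddings=True,
    api_key=None,
    provider_type=None,
    text_type=EmbedTextType.QUERY,
    prefix="query: ",
)
# Cloud provider: prefix must be None; the provider is told the text type instead.
cloud_vectors = embed_text(
    texts=["how do refunds work?"],
    model_name="embed-english-v3.0",  # example Cohere model name
    max_context_length=512,
    normalize_embeddings=True,
    api_key="<cohere-api-key>",
    provider_type="cohere",
    text_type=EmbedTextType.QUERY,
    prefix=None,
)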
@@ -252,7 +297,17 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
async def process_embed_request(
embed_request: EmbedRequest,
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
try:
if embed_request.text_type == EmbedTextType.QUERY:
prefix = embed_request.manual_query_prefix
elif embed_request.text_type == EmbedTextType.PASSAGE:
prefix = embed_request.manual_passage_prefix
else:
prefix = None
embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
@@ -261,13 +316,13 @@ async def process_embed_request(
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)
except Exception as e:
logger.exception(f"Error during embedding process:\n{str(e)}")
raise HTTPException(
status_code=500, detail="Failed to run Bi-Encoder embedding"
)
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
raise HTTPException(status_code=500, detail=exception_detail)
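A sketch of a request body this endpoint can now accept (not part of the diff; field names mirror the updated EmbedRequest model shown further down, values are illustrative, and any fields outside the visible hunks are assumed):
payload = {
    "texts": ["how do I rotate my API key?"],
    "model_name": "intfloat/e5-base-v2",   # example local model
    "max_context_length": 512,
    "normalize_embeddings": True,          # assumed field, not visible in this hunk
    "api_key": None,
    "provider_type": None,
    "text_type": "query",                  # illustrative enum value
    "manual_query_prefix": "query: ",      # applied because this is a query request
    "manual_passage_prefix": "passage: ",  # ignored for query requests
}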
@router.post("/cross-encoder-scores")
@@ -276,6 +331,11 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
if not embed_request.documents or not embed_request.query:
raise HTTPException(
status_code=400, detail="No documents or query to be reranked"
)
try:
sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents

View File

@@ -1,6 +1,7 @@
fastapi==0.109.2
h5py==3.9.0
pydantic==1.10.13
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
tensorflow==2.15.0
@@ -9,5 +10,5 @@ transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
openai==1.14.3
cohere==5.5.8
google-cloud-aiplatform==1.58.0
cohere==5.6.1
google-cloud-aiplatform==1.58.0

View File

@@ -14,7 +14,10 @@ sys.path.append(parent_dir)
# flake8: noqa: E402
# Now import Danswer modules
from danswer.db.models import DocumentSet__ConnectorCredentialPair
from danswer.db.models import (
DocumentSet__ConnectorCredentialPair,
UserGroup__ConnectorCredentialPair,
)
from danswer.db.connector import fetch_connector_by_id
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.index_attempt import (
@@ -44,7 +47,7 @@ logger = setup_logger()
_DELETION_BATCH_SIZE = 1000
def unsafe_deletion(
def _unsafe_deletion(
db_session: Session,
document_index: DocumentIndex,
cc_pair: ConnectorCredentialPair,
@@ -82,11 +85,22 @@ def unsafe_deletion(
credential_id=credential_id,
)
# Delete document sets + connector / credential Pairs
# Delete document sets
stmt = delete(DocumentSet__ConnectorCredentialPair).where(
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id == pair_id
)
db_session.execute(stmt)
# delete user group associations
stmt = delete(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.cc_pair_id == pair_id
)
db_session.execute(stmt)
# need to flush to avoid foreign key violations
db_session.flush()
# delete the actual connector credential pair
stmt = delete(ConnectorCredentialPair).where(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
@@ -168,7 +182,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
files_deleted_count = unsafe_deletion(
files_deleted_count = _unsafe_deletion(
db_session=db_session,
document_index=document_index,
cc_pair=cc_pair,

View File

@@ -4,9 +4,7 @@ from shared_configs.enums import EmbedTextType
class EmbedRequest(BaseModel):
# This already includes any prefixes, the text is just passed directly to the model
texts: list[str]
# Can be none for cloud embedding model requests, error handling logic exists for other cases
model_name: str | None
max_context_length: int
@@ -14,6 +12,8 @@ class EmbedRequest(BaseModel):
api_key: str | None
provider_type: str | None
text_type: EmbedTextType
manual_query_prefix: str | None
manual_passage_prefix: str | None
class EmbedResponse(BaseModel):

View File

@@ -61,6 +61,8 @@ Edit `search_test_config.yaml` to set:
- Set this to true to automatically delete all docker containers, networks and volumes after the test
- launch_web_ui
- Set this to true if you want to use the UI during/after the testing process
- only_state
- Whether to only run Vespa and Postgres
- use_cloud_gpu
- Set to true or false depending on if you want to use the remote gpu
- Only need to set this if use_cloud_gpu is true
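Below is a sketch (not part of the diff) of how these keys are consumed; it mirrors the yaml.safe_load and SimpleNamespace pattern used by the test scripts later in this diff, and the key names follow the documentation above:
import yaml
from types import SimpleNamespace
with open("search_test_config.yaml") as file:
    config = SimpleNamespace(**yaml.safe_load(file))
if config.only_state:     # only bring up Vespa and Postgres
    print("Skipping model servers and web UI")
if config.launch_web_ui:  # keep the UI available during/after the test
    print("Web UI will be launched")
if config.use_cloud_gpu:  # route embedding calls to the remote GPU host
    print("Using the remote GPU model server")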

View File

@@ -39,7 +39,17 @@ def _create_new_chat_session(run_suffix: str) -> int:
raise RuntimeError(response_json)
@retry(tries=10, delay=10)
def _delete_chat_session(chat_session_id: int, run_suffix: str) -> None:
delete_chat_url = _api_url_builder(
run_suffix, f"/chat/delete-chat-session/{chat_session_id}"
)
response = requests.delete(delete_chat_url, headers=GENERAL_HEADERS)
if response.status_code != 200:
raise RuntimeError(response.__dict__)
@retry(tries=5, delay=5)
def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
filters = IndexFilters(
source_type=None,
@@ -78,11 +88,13 @@ def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
print("trying again")
raise e
_delete_chat_session(chat_session_id, run_suffix)
return simple_search_docs, answer
@retry(tries=10, delay=10)
def check_if_query_ready(run_suffix: str) -> bool:
def check_indexing_status(run_suffix: str) -> tuple[int, bool]:
url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/")
try:
indexing_status_dict = requests.get(url, headers=GENERAL_HEADERS).json()
@@ -98,16 +110,17 @@ def check_if_query_ready(run_suffix: str) -> bool:
status = index_attempt["last_status"]
if status == IndexingStatus.IN_PROGRESS or status == IndexingStatus.NOT_STARTED:
ongoing_index_attempts = True
elif status == IndexingStatus.SUCCESS:
doc_count += 16
doc_count += index_attempt["docs_indexed"]
doc_count -= 16
if not doc_count:
print("No docs indexed, waiting for indexing to start")
elif ongoing_index_attempts:
print(
f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
)
return doc_count > 0 and not ongoing_index_attempts
# all the +16 and -16 are to account for the fact that the indexing status
# is only updated every 16 documents and will tell us how many are
# chunked, not indexed. probably need to fix this. in the future!
if doc_count:
doc_count += 16
return doc_count, ongoing_index_attempts
def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:

View File

@@ -283,39 +283,13 @@ def get_api_server_host_port(suffix: str) -> str:
return matching_ports[0]
# Added function to check Vespa container health status
def is_vespa_container_healthy(suffix: str) -> bool:
print(f"Checking health status of Vespa container for suffix: {suffix}")
# Find the Vespa container
stdout, _ = _run_command(
f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
)
container_name = stdout.strip()
if not container_name:
print(f"No Vespa container found with suffix: {suffix}")
return False
# Get the health status
stdout, _ = _run_command(
f"docker inspect --format='{{{{.State.Health.Status}}}}' {container_name}"
)
health_status = stdout.strip()
is_healthy = health_status.lower() == "healthy"
print(f"Vespa container '{container_name}' health status: {health_status}")
return is_healthy
# Added function to restart Vespa container
def restart_vespa_container(suffix: str) -> None:
print(f"Restarting Vespa container for suffix: {suffix}")
# Find the Vespa container
stdout, _ = _run_command(
f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
f"docker ps -a --format '{{{{.Names}}}}' | awk '/index-1/ && /{suffix}/'"
)
container_name = stdout.strip()
@@ -327,11 +301,7 @@ def restart_vespa_container(suffix: str) -> None:
print(f"Vespa container '{container_name}' has begun restarting")
time_to_wait = 5
while not is_vespa_container_healthy(suffix):
print(f"Waiting {time_to_wait} seconds for vespa container to restart")
time.sleep(5)
time.sleep(30)
print(f"Vespa container '{container_name}' has been restarted")

View File

@@ -1,13 +1,38 @@
import os
import tempfile
import time
import zipfile
from pathlib import Path
from types import SimpleNamespace
import yaml
from tests.regression.answer_quality.api_utils import check_indexing_status
from tests.regression.answer_quality.api_utils import create_cc_pair
from tests.regression.answer_quality.api_utils import create_connector
from tests.regression.answer_quality.api_utils import create_credential
from tests.regression.answer_quality.api_utils import run_cc_once
from tests.regression.answer_quality.api_utils import upload_file
from tests.regression.answer_quality.cli_utils import restart_vespa_container
def unzip_and_get_file_paths(zip_file_path: str) -> list[str]:
persistent_dir = tempfile.mkdtemp()
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
zip_ref.extractall(persistent_dir)
return [str(path) for path in Path(persistent_dir).rglob("*") if path.is_file()]
def create_temp_zip_from_files(file_paths: list[str]) -> str:
persistent_dir = tempfile.mkdtemp()
zip_file_path = os.path.join(persistent_dir, "temp.zip")
with zipfile.ZipFile(zip_file_path, "w") as zip_file:
for file_path in file_paths:
zip_file.write(file_path, Path(file_path).name)
return zip_file_path
def upload_test_files(zip_file_path: str, run_suffix: str) -> None:
@@ -21,6 +46,40 @@ def upload_test_files(zip_file_path: str, run_suffix: str) -> None:
run_cc_once(run_suffix, conn_id, cred_id)
def manage_file_upload(zip_file_path: str, run_suffix: str) -> None:
unzipped_file_paths = unzip_and_get_file_paths(zip_file_path)
total_file_count = len(unzipped_file_paths)
while True:
doc_count, ongoing_index_attempts = check_indexing_status(run_suffix)
if not doc_count:
print("No docs indexed, waiting for indexing to start")
upload_test_files(zip_file_path, run_suffix)
elif ongoing_index_attempts:
print(
f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
)
elif doc_count < total_file_count:
print(f"No ongooing indexing attempts but only {doc_count} docs indexed")
print("Restarting vespa...")
restart_vespa_container(run_suffix)
print(f"Rerunning with {total_file_count - doc_count} missing docs")
remaining_files = unzipped_file_paths[doc_count:]
print(f"Grabbed last {len(remaining_files)} docs to try agian")
temp_zip_file_path = create_temp_zip_from_files(remaining_files)
upload_test_files(temp_zip_file_path, run_suffix)
os.unlink(temp_zip_file_path)
else:
print(f"Successfully uploaded {doc_count} docs!")
break
time.sleep(10)
for file in unzipped_file_paths:
os.unlink(file)
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, "search_test_config.yaml")
@@ -28,4 +87,4 @@ if __name__ == "__main__":
config = SimpleNamespace(**yaml.safe_load(file))
file_location = config.zipped_documents_file
run_suffix = config.existing_test_suffix
upload_test_files(file_location, run_suffix)
manage_file_upload(file_location, run_suffix)

View File

@@ -6,7 +6,6 @@ import time
import yaml
from tests.regression.answer_quality.api_utils import check_if_query_ready
from tests.regression.answer_quality.api_utils import get_answer_from_query
from tests.regression.answer_quality.cli_utils import get_current_commit_sha
from tests.regression.answer_quality.cli_utils import get_docker_container_env_vars
@@ -142,27 +141,23 @@ def _process_and_write_query_results(config: dict) -> None:
test_output_folder, questions = _initialize_files(config)
print("saving test results to folder:", test_output_folder)
while not check_if_query_ready(config["run_suffix"]):
time.sleep(5)
if config["limit"] is not None:
questions = questions[: config["limit"]]
with multiprocessing.Pool(processes=multiprocessing.cpu_count() * 2) as pool:
# Use multiprocessing to process questions
with multiprocessing.Pool() as pool:
results = pool.starmap(
_process_question, [(q, config, i + 1) for i, q in enumerate(questions)]
_process_question,
[(question, config, i + 1) for i, question in enumerate(questions)],
)
_populate_results_file(test_output_folder, results)
invalid_answer_count = 0
for result in results:
if not result.get("answer"):
if len(result["context_data_list"]) == 0:
invalid_answer_count += 1
if not result.get("context_data_list"):
raise RuntimeError("Search failed, this is a critical failure!")
_update_metadata_file(test_output_folder, invalid_answer_count)
if invalid_answer_count:

View File

@@ -1,5 +1,6 @@
import { LoadingAnimation } from "@/components/Loading";
import { Button, Divider, Text } from "@tremor/react";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import {
ArrayHelpers,
ErrorMessage,
@@ -15,11 +16,16 @@ import {
SubLabel,
TextArrayField,
TextFormField,
BooleanFormField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { useSWRConfig } from "swr";
import { useUserGroups } from "@/lib/hooks";
import { FullLLMProvider } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
@@ -44,9 +50,16 @@ export function CustomLLMProviderUpdateForm({
}) {
const { mutate } = useSWRConfig();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>("");
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name ?? "",
@@ -61,6 +74,8 @@ export function CustomLLMProviderUpdateForm({
custom_config_list: existingLlmProvider?.custom_config
? Object.entries(existingLlmProvider.custom_config)
: [],
is_public: existingLlmProvider?.is_public ?? true,
groups: existingLlmProvider?.groups ?? [],
};
// Setup validation schema if required
@@ -74,6 +89,9 @@ export function CustomLLMProviderUpdateForm({
default_model_name: Yup.string().required("Model name is required"),
fast_default_model_name: Yup.string().nullable(),
custom_config_list: Yup.array(),
// EE Only
is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
});
return (
@@ -97,6 +115,9 @@ export function CustomLLMProviderUpdateForm({
return;
}
// don't set groups if marked as public
const groups = values.is_public ? [] : values.groups;
// test the configuration
if (!isEqual(values, initialValues)) {
setIsTesting(true);
@@ -188,93 +209,97 @@ export function CustomLLMProviderUpdateForm({
setSubmitting(false);
}}
>
{({ values }) => (
<Form>
<TextFormField
name="name"
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
/>
{({ values, setFieldValue }) => {
return (
<Form>
<TextFormField
name="name"
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
/>
<Divider />
<Divider />
<TextFormField
name="provider"
label="Provider Name"
subtext={
<TextFormField
name="provider"
label="Provider Name"
subtext={
<>
Should be one of the providers listed at{" "}
<a
target="_blank"
href="https://docs.litellm.ai/docs/providers"
className="text-link"
>
https://docs.litellm.ai/docs/providers
</a>
.
</>
}
placeholder="Name of the custom provider"
/>
<Divider />
<SubLabel>
Fill in the following as is needed. Refer to the LiteLLM
documentation for the model provider name specified above in order
to determine which fields are required.
</SubLabel>
<TextFormField
name="api_key"
label="[Optional] API Key"
placeholder="API Key"
type="password"
/>
<TextFormField
name="api_base"
label="[Optional] API Base"
placeholder="API Base"
/>
<TextFormField
name="api_version"
label="[Optional] API Version"
placeholder="API Version"
/>
<Label>[Optional] Custom Configs</Label>
<SubLabel>
<>
Should be one of the providers listed at{" "}
<a
target="_blank"
href="https://docs.litellm.ai/docs/providers"
className="text-link"
>
https://docs.litellm.ai/docs/providers
</a>
.
<div>
Additional configurations needed by the model provider. Are
passed to litellm via environment variables.
</div>
<div className="mt-2">
For example, when configuring the Cloudflare provider, you
would need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
Cloudflare account ID as the value.
</div>
</>
}
placeholder="Name of the custom provider"
/>
</SubLabel>
<Divider />
<SubLabel>
Fill in the following as is needed. Refer to the LiteLLM
documentation for the model provider name specified above in order
to determine which fields are required.
</SubLabel>
<TextFormField
name="api_key"
label="[Optional] API Key"
placeholder="API Key"
type="password"
/>
<TextFormField
name="api_base"
label="[Optional] API Base"
placeholder="API Base"
/>
<TextFormField
name="api_version"
label="[Optional] API Version"
placeholder="API Version"
/>
<Label>[Optional] Custom Configs</Label>
<SubLabel>
<>
<div>
Additional configurations needed by the model provider. Are
passed to litellm via environment variables.
</div>
<div className="mt-2">
For example, when configuring the Cloudflare provider, you would
need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
Cloudflare account ID as the value.
</div>
</>
</SubLabel>
<FieldArray
name="custom_config_list"
render={(arrayHelpers: ArrayHelpers<any[]>) => (
<div>
{values.custom_config_list.map((_, index) => {
return (
<div key={index} className={index === 0 ? "mt-2" : "mt-6"}>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label>Key</Label>
<Field
name={`custom_config_list[${index}][0]`}
className={`
<FieldArray
name="custom_config_list"
render={(arrayHelpers: ArrayHelpers<any[]>) => (
<div>
{values.custom_config_list.map((_, index) => {
return (
<div
key={index}
className={index === 0 ? "mt-2" : "mt-6"}
>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label>Key</Label>
<Field
name={`custom_config_list[${index}][0]`}
className={`
border
border-border
bg-background
@@ -284,20 +309,20 @@ export function CustomLLMProviderUpdateForm({
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][0]`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][0]`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
<div className="mt-3">
<Label>Value</Label>
<Field
name={`custom_config_list[${index}][1]`}
className={`
<div className="mt-3">
<Label>Value</Label>
<Field
name={`custom_config_list[${index}][1]`}
className={`
border
border-border
bg-background
@@ -307,121 +332,190 @@ export function CustomLLMProviderUpdateForm({
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][1]`}
component="div"
className="text-error text-sm mt-1"
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][1]`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() => arrayHelpers.remove(index)}
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() => arrayHelpers.remove(index)}
/>
</div>
</div>
</div>
);
})}
);
})}
<Button
onClick={() => {
arrayHelpers.push(["", ""]);
}}
className="mt-3"
color="green"
size="xs"
type="button"
icon={FiPlus}
>
Add New
</Button>
</div>
)}
/>
<Button
onClick={() => {
arrayHelpers.push(["", ""]);
}}
className="mt-3"
color="green"
size="xs"
type="button"
icon={FiPlus}
>
Add New
</Button>
</div>
)}
/>
<Divider />
<Divider />
<TextArrayField
name="model_names"
label="Model Names"
values={values}
subtext={`List the individual models that you want to make
<TextArrayField
name="model_names"
label="Model Names"
values={values}
subtext={`List the individual models that you want to make
available as a part of this provider. At least one must be specified.
As an example, for OpenAI one model might be "gpt-4".`}
/>
/>
<Divider />
<Divider />
<TextFormField
name="default_model_name"
subtext={`
<TextFormField
name="default_model_name"
subtext={`
The model to use by default for this provider unless
otherwise specified. Must be one of the models listed
above.`}
label="Default Model"
placeholder="E.g. gpt-4"
/>
label="Default Model"
placeholder="E.g. gpt-4"
/>
<TextFormField
name="fast_default_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
<TextFormField
name="fast_default_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
for this provider. If not set, will use
the Default Model configured above.`}
label="[Optional] Fast Model"
placeholder="E.g. gpt-4"
/>
label="[Optional] Fast Model"
placeholder="E.g. gpt-4"
/>
<Divider />
<Divider />
<div>
{/* NOTE: this is above the test button to make sure it's visible */}
{testError && <Text className="text-error mt-2">{testError}</Text>}
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
/>
<div className="flex w-full mt-4">
<Button type="submit" size="xs">
{isTesting ? (
<LoadingAnimation text="Testing" />
) : existingLlmProvider ? (
"Update"
) : (
"Enable"
{showAdvancedOptions && (
<>
{isPaidEnterpriseFeaturesEnabled && userGroups && (
<>
<BooleanFormField
small
noPadding
alignTop
name="is_public"
label="Is Public?"
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
/>
{userGroups &&
userGroups.length > 0 &&
!values.is_public && (
<div>
<Text>
Select which User Groups should have access to this
LLM Provider.
</Text>
<div className="flex flex-wrap gap-2 mt-2">
{userGroups.map((userGroup) => {
const isSelected = values.groups.includes(
userGroup.id
);
return (
<Bubble
key={userGroup.id}
isSelected={isSelected}
onClick={() => {
if (isSelected) {
setFieldValue(
"groups",
values.groups.filter(
(id) => id !== userGroup.id
)
);
} else {
setFieldValue("groups", [
...values.groups,
userGroup.id,
]);
}
}}
>
<div className="flex">
<GroupsIcon />
<div className="ml-1">{userGroup.name}</div>
</div>
</Bubble>
);
})}
</div>
</div>
)}
</>
)}
</Button>
{existingLlmProvider && (
<Button
type="button"
color="red"
className="ml-3"
size="xs"
icon={FiTrash}
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
{
method: "DELETE",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
alert(`Failed to delete provider: ${errorMsg}`);
return;
}
</>
)}
mutate(LLM_PROVIDERS_ADMIN_URL);
onClose();
}}
>
Delete
</Button>
<div>
{/* NOTE: this is above the test button to make sure it's visible */}
{testError && (
<Text className="text-error mt-2">{testError}</Text>
)}
<div className="flex w-full mt-4">
<Button type="submit" size="xs">
{isTesting ? (
<LoadingAnimation text="Testing" />
) : existingLlmProvider ? (
"Update"
) : (
"Enable"
)}
</Button>
{existingLlmProvider && (
<Button
type="button"
color="red"
className="ml-3"
size="xs"
icon={FiTrash}
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
{
method: "DELETE",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
alert(`Failed to delete provider: ${errorMsg}`);
return;
}
mutate(LLM_PROVIDERS_ADMIN_URL);
onClose();
}}
>
Delete
</Button>
)}
</div>
</div>
</div>
</Form>
)}
</Form>
);
}}
</Formik>
);
}

View File

@@ -1,4 +1,5 @@
import { LoadingAnimation } from "@/components/Loading";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { Button, Divider, Text } from "@tremor/react";
import { Form, Formik } from "formik";
import { FiTrash } from "react-icons/fi";
@@ -6,11 +7,16 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import {
SelectorFormField,
TextFormField,
BooleanFormField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { useSWRConfig } from "swr";
import { useUserGroups } from "@/lib/hooks";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
@@ -29,9 +35,16 @@ export function LLMProviderUpdateForm({
}) {
const { mutate } = useSWRConfig();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>("");
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name ?? "",
@@ -54,6 +67,8 @@ export function LLMProviderUpdateForm({
},
{} as { [key: string]: string }
),
is_public: existingLlmProvider?.is_public ?? true,
groups: existingLlmProvider?.groups ?? [],
};
const [validatedConfig, setValidatedConfig] = useState(
@@ -91,6 +106,9 @@ export function LLMProviderUpdateForm({
: {}),
default_model_name: Yup.string().required("Model name is required"),
fast_default_model_name: Yup.string().nullable(),
// EE Only
is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
});
return (
@@ -193,7 +211,7 @@ export function LLMProviderUpdateForm({
setSubmitting(false);
}}
>
{({ values }) => (
{({ values, setFieldValue }) => (
<Form>
<TextFormField
name="name"
@@ -293,6 +311,69 @@ export function LLMProviderUpdateForm({
<Divider />
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
/>
{showAdvancedOptions && (
<>
{isPaidEnterpriseFeaturesEnabled && userGroups && (
<>
<BooleanFormField
small
noPadding
alignTop
name="is_public"
label="Is Public?"
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
/>
{userGroups && userGroups.length > 0 && !values.is_public && (
<div>
<Text>
Select which User Groups should have access to this LLM
Provider.
</Text>
<div className="flex flex-wrap gap-2 mt-2">
{userGroups.map((userGroup) => {
const isSelected = values.groups.includes(
userGroup.id
);
return (
<Bubble
key={userGroup.id}
isSelected={isSelected}
onClick={() => {
if (isSelected) {
setFieldValue(
"groups",
values.groups.filter(
(id) => id !== userGroup.id
)
);
} else {
setFieldValue("groups", [
...values.groups,
userGroup.id,
]);
}
}}
>
<div className="flex">
<GroupsIcon />
<div className="ml-1">{userGroup.name}</div>
</div>
</Bubble>
);
})}
</div>
</div>
)}
</>
)}
</>
)}
<div>
{/* NOTE: this is above the test button to make sure it's visible */}
{testError && <Text className="text-error mt-2">{testError}</Text>}

View File

@@ -17,6 +17,8 @@ export interface WellKnownLLMProviderDescriptor {
llm_names: string[];
default_model: string | null;
default_fast_model: string | null;
is_public: boolean;
groups: number[];
}
export interface LLMProvider {
@@ -28,6 +30,8 @@ export interface LLMProvider {
custom_config: { [key: string]: string } | null;
default_model_name: string;
fast_default_model_name: string | null;
is_public: boolean;
groups: number[];
}
export interface FullLLMProvider extends LLMProvider {
@@ -44,4 +48,6 @@ export interface LLMProviderDescriptor {
default_model_name: string;
fast_default_model_name: string | null;
is_default_provider: boolean | null;
is_public: boolean;
groups: number[];
}

View File

@@ -50,7 +50,7 @@ export default async function GalleryPage({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availablePersonas: assistants,
availableAssistants: assistants,
availableTags: tags,
llmProviders,
folders,

View File

@@ -52,7 +52,7 @@ export default async function GalleryPage({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availablePersonas: assistants,
availableAssistants: assistants,
availableTags: tags,
llmProviders,
folders,

View File

@@ -14,15 +14,18 @@ export function ChatBanner() {
return (
<div
className={`
z-[39]
h-[30px]
bg-background-100
shadow-sm
m-2
rounded
border-border
border
flex`}
mt-8
mb-2
mx-2
z-[39]
w-full
h-[30px]
bg-background-100
shadow-sm
rounded
border-border
border
flex`}
>
<div className="mx-auto text-emphasis text-sm flex flex-col">
<div className="my-auto">

View File

@@ -63,7 +63,6 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
import { ConfigurationModal } from "./modal/configuration/ConfigurationModal";
import { useChatContext } from "@/components/context/ChatContext";
import { v4 as uuidv4 } from "uuid";
import { orderAssistantsForUser } from "@/lib/assistants/orderAssistants";
@@ -82,38 +81,29 @@ const SYSTEM_MESSAGE_ID = -3;
export function ChatPage({
toggle,
documentSidebarInitialWidth,
defaultSelectedPersonaId,
defaultSelectedAssistantId,
toggledSidebar,
}: {
toggle: () => void;
documentSidebarInitialWidth?: number;
defaultSelectedPersonaId?: number;
defaultSelectedAssistantId?: number;
toggledSidebar: boolean;
}) {
const [configModalActiveTab, setConfigModalActiveTab] = useState<
string | null
>(null);
const router = useRouter();
const searchParams = useSearchParams();
let {
user,
chatSessions,
availableSources,
availableDocumentSets,
availablePersonas,
availableAssistants,
llmProviders,
folders,
openedFolders,
} = useChatContext();
const filteredAssistants = orderAssistantsForUser(availablePersonas, user);
const [selectedAssistant, setSelectedAssistant] = useState<Persona | null>(
null
);
const [alternativeGeneratingAssistant, setAlternativeGeneratingAssistant] =
useState<Persona | null>(null);
const router = useRouter();
const searchParams = useSearchParams();
// chat session
const existingChatIdRaw = searchParams.get("chatId");
const existingChatSessionId = existingChatIdRaw
? parseInt(existingChatIdRaw)
@@ -123,9 +113,51 @@ export function ChatPage({
);
const chatSessionIdRef = useRef<number | null>(existingChatSessionId);
// LLM
const llmOverrideManager = useLlmOverride(selectedChatSession);
const existingChatSessionPersonaId = selectedChatSession?.persona_id;
// Assistants
const filteredAssistants = orderAssistantsForUser(availableAssistants, user);
const existingChatSessionAssistantId = selectedChatSession?.persona_id;
const [selectedAssistant, setSelectedAssistant] = useState<
Persona | undefined
>(
// NOTE: look through available assistants here, so that even if the user
// has hidden this assistant it still shows the correct assistant when
// going back to an old chat session
existingChatSessionAssistantId !== undefined
? availableAssistants.find(
(assistant) => assistant.id === existingChatSessionAssistantId
)
: defaultSelectedAssistantId !== undefined
? availableAssistants.find(
(assistant) => assistant.id === defaultSelectedAssistantId
)
: undefined
);
const setSelectedAssistantFromId = (assistantId: number) => {
// NOTE: also intentionally look through available assistants here, so that
// even if the user has hidden an assistant they can still go back to it
// for old chats
setSelectedAssistant(
availableAssistants.find((assistant) => assistant.id === assistantId)
);
};
const liveAssistant =
selectedAssistant || filteredAssistants[0] || availableAssistants[0];
// this is for "@"ing assistants
const [alternativeAssistant, setAlternativeAssistant] =
useState<Persona | null>(null);
// this is used to track which assistant is being used to generate the current message
// for example, this would come into play when:
// 1. default assistant is `Danswer`
// 2. we "@"ed the `GPT` assistant and sent a message
// 3. while the `GPT` assistant message is generating, we "@" the `Paraphrase` assistant
const [alternativeGeneratingAssistant, setAlternativeGeneratingAssistant] =
useState<Persona | null>(null);
// used to track whether or not the initial "submit on load" has been performed
// this only applies if `?submit-on-load=true` or `?submit-on-load=1` is in the URL
@@ -182,14 +214,10 @@ export function ChatPage({
async function initialSessionFetch() {
if (existingChatSessionId === null) {
setIsFetchingChatMessages(false);
if (defaultSelectedPersonaId !== undefined) {
setSelectedPersona(
filteredAssistants.find(
(persona) => persona.id === defaultSelectedPersonaId
)
);
if (defaultSelectedAssistantId !== undefined) {
setSelectedAssistantFromId(defaultSelectedAssistantId);
} else {
setSelectedPersona(undefined);
setSelectedAssistant(undefined);
}
setCompleteMessageDetail({
sessionId: null,
@@ -214,12 +242,7 @@ export function ChatPage({
);
const chatSession = (await response.json()) as BackendChatSession;
setSelectedPersona(
filteredAssistants.find(
(persona) => persona.id === chatSession.persona_id
)
);
setSelectedAssistantFromId(chatSession.persona_id);
const newMessageMap = processRawChatHistory(chatSession.messages);
const newMessageHistory = buildLatestMessageChain(newMessageMap);
@@ -373,32 +396,18 @@ export function ChatPage({
)
: { aiMessage: null };
const [selectedPersona, setSelectedPersona] = useState<Persona | undefined>(
existingChatSessionPersonaId !== undefined
? filteredAssistants.find(
(persona) => persona.id === existingChatSessionPersonaId
)
: defaultSelectedPersonaId !== undefined
? filteredAssistants.find(
(persona) => persona.id === defaultSelectedPersonaId
)
: undefined
);
const livePersona =
selectedPersona || filteredAssistants[0] || availablePersonas[0];
const [chatSessionSharedStatus, setChatSessionSharedStatus] =
useState<ChatSessionSharedStatus>(ChatSessionSharedStatus.Private);
useEffect(() => {
if (messageHistory.length === 0 && chatSessionIdRef.current === null) {
setSelectedPersona(
setSelectedAssistant(
filteredAssistants.find(
(persona) => persona.id === defaultSelectedPersonaId
(persona) => persona.id === defaultSelectedAssistantId
)
);
}
}, [defaultSelectedPersonaId]);
}, [defaultSelectedAssistantId]);
const [
selectedDocuments,
@@ -414,7 +423,7 @@ export function ChatPage({
useEffect(() => {
async function fetchMaxTokens() {
const response = await fetch(
`/api/chat/max-selected-document-tokens?persona_id=${livePersona.id}`
`/api/chat/max-selected-document-tokens?persona_id=${liveAssistant.id}`
);
if (response.ok) {
const maxTokens = (await response.json()).max_tokens as number;
@@ -423,12 +432,12 @@ export function ChatPage({
}
fetchMaxTokens();
}, [livePersona]);
}, [liveAssistant]);
const filterManager = useFilters();
const [finalAvailableSources, finalAvailableDocumentSets] =
computeAvailableFilters({
selectedPersona,
selectedPersona: selectedAssistant,
availableSources,
availableDocumentSets,
});
@@ -624,16 +633,16 @@ export function ChatPage({
queryOverride,
forceSearch,
isSeededChat,
alternativeAssistant = null,
alternativeAssistantOverride = null,
}: {
messageIdToResend?: number;
messageOverride?: string;
queryOverride?: string;
forceSearch?: boolean;
isSeededChat?: boolean;
alternativeAssistant?: Persona | null;
alternativeAssistantOverride?: Persona | null;
} = {}) => {
setAlternativeGeneratingAssistant(alternativeAssistant);
setAlternativeGeneratingAssistant(alternativeAssistantOverride);
clientScrollToBottom();
let currChatSessionId: number;
@@ -643,7 +652,7 @@ export function ChatPage({
if (isNewSession) {
currChatSessionId = await createChatSession(
livePersona?.id || 0,
liveAssistant?.id || 0,
searchParamBasedChatSessionName
);
} else {
@@ -721,9 +730,9 @@ export function ChatPage({
parentMessage = frozenMessageMap.get(SYSTEM_MESSAGE_ID) || null;
}
const currentAssistantId = alternativeAssistant
? alternativeAssistant.id
: selectedAssistant?.id;
const currentAssistantId = alternativeAssistantOverride
? alternativeAssistantOverride.id
: alternativeAssistant?.id || liveAssistant.id;
resetInputBar();
@@ -751,7 +760,7 @@ export function ChatPage({
fileDescriptors: currentMessageFiles,
parentMessageId: lastSuccessfulMessageId,
chatSessionId: currChatSessionId,
promptId: livePersona?.prompts[0]?.id || 0,
promptId: liveAssistant?.prompts[0]?.id || 0,
filters: buildFilters(
filterManager.selectedSources,
filterManager.selectedDocumentSets,
@@ -868,7 +877,7 @@ export function ChatPage({
files: finalMessage?.files || aiMessageImages || [],
toolCalls: finalMessage?.tool_calls || toolCalls,
parentMessageId: newUserMessageId,
alternateAssistantID: selectedAssistant?.id,
alternateAssistantID: alternativeAssistant?.id,
},
]);
}
@@ -964,19 +973,23 @@ export function ChatPage({
}
};
const onPersonaChange = (persona: Persona | null) => {
if (persona && persona.id !== livePersona.id) {
const onAssistantChange = (assistant: Persona | null) => {
if (assistant && assistant.id !== liveAssistant.id) {
// remove uploaded files
setCurrentMessageFiles([]);
setSelectedPersona(persona);
setSelectedAssistant(assistant);
textAreaRef.current?.focus();
router.push(buildChatUrl(searchParams, null, persona.id));
router.push(buildChatUrl(searchParams, null, assistant.id));
}
};
const handleImageUpload = (acceptedFiles: File[]) => {
const llmAcceptsImages = checkLLMSupportsImageInput(
...getFinalLLM(llmProviders, livePersona, llmOverrideManager.llmOverride)
...getFinalLLM(
llmProviders,
liveAssistant,
llmOverrideManager.llmOverride
)
);
const imageFiles = acceptedFiles.filter((file) =>
file.type.startsWith("image/")
@@ -1058,23 +1071,23 @@ export function ChatPage({
useEffect(() => {
const includes = checkAnyAssistantHasSearch(
messageHistory,
availablePersonas,
livePersona
availableAssistants,
liveAssistant
);
setRetrievalEnabled(includes);
}, [messageHistory, availablePersonas, livePersona]);
}, [messageHistory, availableAssistants, liveAssistant]);
const [retrievalEnabled, setRetrievalEnabled] = useState(() => {
return checkAnyAssistantHasSearch(
messageHistory,
availablePersonas,
livePersona
availableAssistants,
liveAssistant
);
});
const innerSidebarElementRef = useRef<HTMLDivElement>(null);
const currentPersona = selectedAssistant || livePersona;
const currentPersona = alternativeAssistant || liveAssistant;
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
@@ -1176,21 +1189,8 @@ export function ChatPage({
/>
)}
<ConfigurationModal
chatSessionId={chatSessionIdRef.current!}
activeTab={configModalActiveTab}
setActiveTab={setConfigModalActiveTab}
onClose={() => setConfigModalActiveTab(null)}
filterManager={filterManager}
availableAssistants={filteredAssistants}
selectedAssistant={livePersona}
setSelectedAssistant={onPersonaChange}
llmProviders={llmProviders}
llmOverrideManager={llmOverrideManager}
/>
<div className="flex h-[calc(100dvh)] flex-col w-full">
{livePersona && (
{liveAssistant && (
<FunctionalHeader
page="chat"
setSharingModalVisible={
@@ -1203,6 +1203,23 @@ export function ChatPage({
currentChatSession={selectedChatSession}
/>
)}
<div className="w-full flex">
<div
style={{ transition: "width 0.30s ease-out" }}
className={`
flex-none
overflow-y-hidden
bg-background-100
transition-all
bg-opacity-80
duration-300
ease-in-out
h-full
${toggledSidebar || showDocSidebar ? "w-[300px]" : "w-[0px]"}
`}
/>
<ChatBanner />
</div>
{documentSidebarInitialWidth !== undefined ? (
<Dropzone onDrop={handleImageUpload} noClick>
{({ getRootProps }) => (
@@ -1231,15 +1248,14 @@ export function ChatPage({
ref={scrollableDivRef}
>
{/* ChatBanner is a custom banner that displays a admin-specified message at
the top of the chat page. Only used in the EE version of the app. */}
<ChatBanner />
the top of the chat page. Only used in the EE version of the app. */}
{messageHistory.length === 0 &&
!isFetchingChatMessages &&
!isStreaming && (
<ChatIntro
availableSources={finalAvailableSources}
selectedPersona={livePersona}
selectedPersona={liveAssistant}
/>
)}
<div
@@ -1319,7 +1335,7 @@ export function ChatPage({
const currentAlternativeAssistant =
message.alternateAssistantID != null
? availablePersonas.find(
? availableAssistants.find(
(persona) =>
persona.id ==
message.alternateAssistantID
@@ -1342,7 +1358,7 @@ export function ChatPage({
toggleDocumentSelectionAspects
}
docs={message.documents}
currentPersona={livePersona}
currentPersona={liveAssistant}
alternativeAssistant={
currentAlternativeAssistant
}
@@ -1352,7 +1368,7 @@ export function ChatPage({
query={
messageHistory[i]?.query || undefined
}
personaName={livePersona.name}
personaName={liveAssistant.name}
citedDocuments={getCitedDocumentsFromMessage(
message
)}
@@ -1404,7 +1420,7 @@ export function ChatPage({
messageIdToResend:
previousMessage.messageId,
queryOverride: newQuery,
alternativeAssistant:
alternativeAssistantOverride:
currentAlternativeAssistant,
});
}
@@ -1435,7 +1451,7 @@ export function ChatPage({
messageIdToResend:
previousMessage.messageId,
forceSearch: true,
alternativeAssistant:
alternativeAssistantOverride:
currentAlternativeAssistant,
});
} else {
@@ -1460,9 +1476,9 @@ export function ChatPage({
return (
<div key={messageReactComponentKey}>
<AIMessage
currentPersona={livePersona}
currentPersona={liveAssistant}
messageId={message.messageId}
personaName={livePersona.name}
personaName={liveAssistant.name}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
@@ -1481,13 +1497,13 @@ export function ChatPage({
key={`${messageHistory.length}-${chatSessionIdRef.current}`}
>
<AIMessage
currentPersona={livePersona}
currentPersona={liveAssistant}
alternativeAssistant={
alternativeGeneratingAssistant ??
selectedAssistant
alternativeAssistant
}
messageId={null}
personaName={livePersona.name}
personaName={liveAssistant.name}
content={
<div className="text-sm my-auto">
<ThreeDots
@@ -1513,7 +1529,7 @@ export function ChatPage({
{currentPersona &&
currentPersona.starter_messages &&
currentPersona.starter_messages.length > 0 &&
selectedPersona &&
selectedAssistant &&
messageHistory.length === 0 &&
!isFetchingChatMessages && (
<div
@@ -1570,32 +1586,25 @@ export function ChatPage({
<ChatInputBar
showDocs={() => setDocumentSelection(true)}
selectedDocuments={selectedDocuments}
setSelectedAssistant={onPersonaChange}
onSetSelectedAssistant={(
alternativeAssistant: Persona | null
) => {
setSelectedAssistant(alternativeAssistant);
}}
alternativeAssistant={selectedAssistant}
personas={filteredAssistants}
// assistant stuff
assistantOptions={filteredAssistants}
selectedAssistant={liveAssistant}
setSelectedAssistant={onAssistantChange}
setAlternativeAssistant={setAlternativeAssistant}
alternativeAssistant={alternativeAssistant}
// end assistant stuff
message={message}
setMessage={setMessage}
onSubmit={onSubmit}
isStreaming={isStreaming}
setIsCancelled={setIsCancelled}
retrievalDisabled={
!personaIncludesRetrieval(currentPersona)
}
filterManager={filterManager}
llmOverrideManager={llmOverrideManager}
selectedAssistant={livePersona}
files={currentMessageFiles}
setFiles={setCurrentMessageFiles}
handleFileUpload={handleImageUpload}
setConfigModalActiveTab={setConfigModalActiveTab}
textAreaRef={textAreaRef}
chatSessionId={chatSessionIdRef.current!}
availableAssistants={availablePersonas}
/>
</div>
</div>

View File

@@ -5,10 +5,10 @@ import { ChatPage } from "./ChatPage";
import FunctionalWrapper from "./shared_chat_search/FunctionalWrapper";
export default function WrappedChat({
defaultPersonaId,
defaultAssistantId,
initiallyToggled,
}: {
defaultPersonaId?: number;
defaultAssistantId?: number;
initiallyToggled: boolean;
}) {
return (
@@ -17,7 +17,7 @@ export default function WrappedChat({
content={(toggledSidebar, toggle) => (
<ChatPage
toggle={toggle}
defaultSelectedPersonaId={defaultPersonaId}
defaultSelectedAssistantId={defaultAssistantId}
toggledSidebar={toggledSidebar}
/>
)}

View File

@@ -21,7 +21,6 @@ import { IconType } from "react-icons";
import Popup from "../../../components/popup/Popup";
import { LlmTab } from "../modal/configuration/LlmTab";
import { AssistantsTab } from "../modal/configuration/AssistantsTab";
import ChatInputAssistant from "./ChatInputAssistant";
import { DanswerDocument } from "@/lib/search/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Tooltip } from "@/components/tooltip/Tooltip";
@@ -29,7 +28,6 @@ import { Hoverable } from "@/components/Hoverable";
const MAX_INPUT_HEIGHT = 200;
export function ChatInputBar({
personas,
showDocs,
selectedDocuments,
message,
@@ -37,34 +35,32 @@ export function ChatInputBar({
onSubmit,
isStreaming,
setIsCancelled,
retrievalDisabled,
filterManager,
llmOverrideManager,
onSetSelectedAssistant,
selectedAssistant,
files,
// assistants
selectedAssistant,
assistantOptions,
setSelectedAssistant,
setAlternativeAssistant,
files,
setFiles,
handleFileUpload,
setConfigModalActiveTab,
textAreaRef,
alternativeAssistant,
chatSessionId,
availableAssistants,
}: {
showDocs: () => void;
selectedDocuments: DanswerDocument[];
availableAssistants: Persona[];
onSetSelectedAssistant: (alternativeAssistant: Persona | null) => void;
assistantOptions: Persona[];
setAlternativeAssistant: (alternativeAssistant: Persona | null) => void;
setSelectedAssistant: (assistant: Persona) => void;
personas: Persona[];
message: string;
setMessage: (message: string) => void;
onSubmit: () => void;
isStreaming: boolean;
setIsCancelled: (value: boolean) => void;
retrievalDisabled: boolean;
filterManager: FilterManager;
llmOverrideManager: LlmOverrideManager;
selectedAssistant: Persona;
@@ -72,7 +68,6 @@ export function ChatInputBar({
files: FileDescriptor[];
setFiles: (files: FileDescriptor[]) => void;
handleFileUpload: (files: File[]) => void;
setConfigModalActiveTab: (tab: string) => void;
textAreaRef: React.RefObject<HTMLTextAreaElement>;
chatSessionId?: number;
}) {
@@ -136,8 +131,10 @@ export function ChatInputBar({
};
// Update selected persona
const updateCurrentPersona = (persona: Persona) => {
onSetSelectedAssistant(persona.id == selectedAssistant.id ? null : persona);
const updatedTaggedAssistant = (assistant: Persona) => {
setAlternativeAssistant(
assistant.id == selectedAssistant.id ? null : assistant
);
hideSuggestions();
setMessage("");
};
@@ -160,8 +157,8 @@ export function ChatInputBar({
}
};
const filteredPersonas = personas.filter((persona) =>
persona.name.toLowerCase().startsWith(
const assistantTagOptions = assistantOptions.filter((assistant) =>
assistant.name.toLowerCase().startsWith(
message
.slice(message.lastIndexOf("@") + 1)
.split(/\s/)[0]
@@ -174,18 +171,18 @@ export function ChatInputBar({
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (
showSuggestions &&
filteredPersonas.length > 0 &&
assistantTagOptions.length > 0 &&
(e.key === "Tab" || e.key == "Enter")
) {
e.preventDefault();
if (assistantIconIndex == filteredPersonas.length) {
if (assistantIconIndex == assistantTagOptions.length) {
window.open("/assistants/new", "_blank");
hideSuggestions();
setMessage("");
} else {
const option =
filteredPersonas[assistantIconIndex >= 0 ? assistantIconIndex : 0];
updateCurrentPersona(option);
assistantTagOptions[assistantIconIndex >= 0 ? assistantIconIndex : 0];
updatedTaggedAssistant(option);
}
}
if (!showSuggestions) {
@@ -195,7 +192,7 @@ export function ChatInputBar({
if (e.key === "ArrowDown") {
e.preventDefault();
setAssistantIconIndex((assistantIconIndex) =>
Math.min(assistantIconIndex + 1, filteredPersonas.length)
Math.min(assistantIconIndex + 1, assistantTagOptions.length)
);
} else if (e.key === "ArrowUp") {
e.preventDefault();
@@ -219,35 +216,36 @@ export function ChatInputBar({
mx-auto
"
>
{showSuggestions && filteredPersonas.length > 0 && (
{showSuggestions && assistantTagOptions.length > 0 && (
<div
ref={suggestionsRef}
className="text-sm absolute inset-x-0 top-0 w-full transform -translate-y-full"
>
<div className="rounded-lg py-1.5 bg-background border border-border-medium shadow-lg mx-2 px-1.5 mt-2 rounded z-10">
{filteredPersonas.map((currentPersona, index) => (
{assistantTagOptions.map((currentAssistant, index) => (
<button
key={index}
className={`px-2 ${
assistantIconIndex == index && "bg-hover-lightish"
} rounded rounded-lg content-start flex gap-x-1 py-2 w-full hover:bg-hover-lightish cursor-pointer`}
onClick={() => {
updateCurrentPersona(currentPersona);
updatedTaggedAssistant(currentAssistant);
}}
>
<p className="font-bold">{currentPersona.name}</p>
<p className="font-bold">{currentAssistant.name}</p>
<p className="line-clamp-1">
{currentPersona.id == selectedAssistant.id &&
{currentAssistant.id == selectedAssistant.id &&
"(default) "}
{currentPersona.description}
{currentAssistant.description}
</p>
</button>
))}
<a
key={filteredPersonas.length}
key={assistantTagOptions.length}
target="_blank"
className={`${
assistantIconIndex == filteredPersonas.length && "bg-hover"
assistantIconIndex == assistantTagOptions.length &&
"bg-hover"
} rounded rounded-lg px-3 flex gap-x-1 py-2 w-full items-center hover:bg-hover-lightish cursor-pointer"`}
href="/assistants/new"
>
@@ -301,7 +299,7 @@ export function ChatInputBar({
<Hoverable
icon={FiX}
onClick={() => onSetSelectedAssistant(null)}
onClick={() => setAlternativeAssistant(null)}
/>
</div>
</div>
@@ -409,7 +407,7 @@ export function ChatInputBar({
removePadding
content={(close) => (
<AssistantsTab
availableAssistants={availableAssistants}
availableAssistants={assistantOptions}
llmProviders={llmProviders}
selectedAssistant={selectedAssistant}
onSelect={(assistant) => {

View File

@@ -505,7 +505,7 @@ export function removeMessage(
export function checkAnyAssistantHasSearch(
messageHistory: Message[],
availablePersonas: Persona[],
availableAssistants: Persona[],
livePersona: Persona
): boolean {
const response =
@@ -516,8 +516,8 @@ export function checkAnyAssistantHasSearch(
) {
return false;
}
const alternateAssistant = availablePersonas.find(
(persona) => persona.id === message.alternateAssistantID
const alternateAssistant = availableAssistants.find(
(assistant) => assistant.id === message.alternateAssistantID
);
return alternateAssistant
? personaIncludesRetrieval(alternateAssistant)

View File

@@ -58,7 +58,6 @@ import { ValidSources } from "@/lib/types";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { useMouseTracking } from "./hooks";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { getTitleFromDocument } from "@/lib/sources";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@@ -445,7 +444,8 @@ export const AIMessage = ({
>
<Citation link={doc.link} index={ind + 1} />
<p className="shrink truncate ellipsis break-all ">
{getTitleFromDocument(doc)}
{doc.semantic_identifier ||
doc.document_id}
</p>
<div className="ml-auto flex-none">
{doc.is_internet ? (
@@ -798,7 +798,7 @@ export const HumanMessage = ({
!isEditing &&
(!files || files.length === 0)
) && "ml-auto"
} relative max-w-[70%] mb-auto rounded-3xl bg-user px-5 py-2.5`}
} relative max-w-[70%] mb-auto whitespace-break-spaces rounded-3xl bg-user px-5 py-2.5`}
>
{content}
</div>

View File

@@ -62,11 +62,10 @@ export function AssistantsTab({
toolName = "Image Generation";
toolIcon = <FiImage className="mr-1 my-auto" />;
}
return (
<Bubble key={tool.id} isSelected={false}>
<div className="flex flex-row gap-1">
{toolIcon}
<div className="flex line-wrap break-all flex-row gap-1">
<div className="flex-none my-auto">{toolIcon}</div>
{toolName}
</div>
</Bubble>

View File

@@ -1,180 +0,0 @@
"use client";
import React, { useEffect } from "react";
import { Modal } from "../../../../components/Modal";
import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
import { FiltersTab } from "./FiltersTab";
import { FiCpu, FiFilter, FiX } from "react-icons/fi";
import { IconType } from "react-icons";
import { FaBrain } from "react-icons/fa";
import { AssistantsTab } from "./AssistantsTab";
import { Persona } from "@/app/admin/assistants/interfaces";
import { LlmTab } from "./LlmTab";
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
import { AssistantsIcon, IconProps } from "@/components/icons/icons";
const TabButton = ({
label,
icon: Icon,
isActive,
onClick,
}: {
label: string;
icon: IconType;
isActive: boolean;
onClick: () => void;
}) => (
<button
onClick={onClick}
className={`
pb-4
pt-6
px-2
text-emphasis
font-bold
${isActive ? "border-b-2 border-accent" : ""}
hover:bg-hover-light
hover:text-strong
transition
duration-200
ease-in-out
flex
`}
>
<Icon className="inline-block mr-2 my-auto" size="16" />
<p className="my-auto">{label}</p>
</button>
);
export function ConfigurationModal({
activeTab,
setActiveTab,
onClose,
availableAssistants,
selectedAssistant,
setSelectedAssistant,
filterManager,
llmProviders,
llmOverrideManager,
chatSessionId,
}: {
activeTab: string | null;
setActiveTab: (tab: string | null) => void;
onClose: () => void;
availableAssistants: Persona[];
selectedAssistant: Persona;
setSelectedAssistant: (assistant: Persona) => void;
filterManager: FilterManager;
llmProviders: LLMProviderDescriptor[];
llmOverrideManager: LlmOverrideManager;
chatSessionId?: number;
}) {
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === "Escape") {
onClose();
}
};
document.addEventListener("keydown", handleKeyDown);
return () => {
document.removeEventListener("keydown", handleKeyDown);
};
}, [onClose]);
if (!activeTab) return null;
return (
<Modal
onOutsideClick={onClose}
noPadding
className="
w-4/6
h-4/6
flex
flex-col
"
>
<div className="rounded flex flex-col overflow-hidden">
<div className="mb-4">
<div className="flex border-b border-border bg-background-emphasis">
<div className="flex px-6 gap-x-2">
<TabButton
label="Assistants"
icon={FaBrain}
isActive={activeTab === "assistants"}
onClick={() => setActiveTab("assistants")}
/>
<TabButton
label="Models"
icon={FiCpu}
isActive={activeTab === "llms"}
onClick={() => setActiveTab("llms")}
/>
<TabButton
label="Filters"
icon={FiFilter}
isActive={activeTab === "filters"}
onClick={() => setActiveTab("filters")}
/>
</div>
<button
className="
ml-auto
px-1
py-1
text-xs
font-medium
rounded
hover:bg-hover
focus:outline-none
focus:ring-2
focus:ring-offset-2
focus:ring-subtle
flex
items-center
h-fit
my-auto
mr-5
"
onClick={onClose}
>
<FiX size={24} />
</button>
</div>
</div>
<div className="flex flex-col overflow-y-auto">
<div className="px-8 pt-4">
{activeTab === "filters" && (
<FiltersTab filterManager={filterManager} />
)}
{activeTab === "llms" && (
<LlmTab
chatSessionId={chatSessionId}
llmOverrideManager={llmOverrideManager}
currentAssistant={selectedAssistant}
/>
)}
{activeTab === "assistants" && (
<div>
<AssistantsTab
availableAssistants={availableAssistants}
llmProviders={llmProviders}
selectedAssistant={selectedAssistant}
onSelect={(assistant) => {
setSelectedAssistant(assistant);
onClose();
}}
/>
</div>
)}
</div>
</div>
</div>
</Modal>
);
}

View File

@@ -35,7 +35,7 @@ export default async function Page({
folders,
toggleSidebar,
openedFolders,
defaultPersonaId,
defaultAssistantId,
finalDocumentSidebarInitialWidth,
shouldShowWelcomeModal,
shouldDisplaySourcesIncompleteModal,
@@ -58,7 +58,7 @@ export default async function Page({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availablePersonas: assistants,
availableAssistants: assistants,
availableTags: tags,
llmProviders,
folders,
@@ -66,7 +66,7 @@ export default async function Page({
}}
>
<WrappedChat
defaultPersonaId={defaultPersonaId}
defaultAssistantId={defaultAssistantId}
initiallyToggled={toggleSidebar}
/>
</ChatProvider>

View File

@@ -1,111 +0,0 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import { BasicSelectable } from "@/components/BasicClickable";
import { AssistantsIcon } from "@/components/icons/icons";
import { User } from "@/lib/types";
import { Text } from "@tremor/react";
import Link from "next/link";
import { FaRobot } from "react-icons/fa";
import { FiEdit2 } from "react-icons/fi";
function AssistantDisplay({
persona,
onSelect,
user,
}: {
persona: Persona;
onSelect: (persona: Persona) => void;
user: User | null;
}) {
const isEditable =
(!user || user.id === persona.owner?.id) &&
!persona.default_persona &&
(!persona.is_public || !user || user.role === "admin");
return (
<div className="flex">
<div className="w-full" onClick={() => onSelect(persona)}>
<BasicSelectable selected={false} fullWidth>
<div className="flex">
<div className="truncate w-48 3xl:w-56 flex">
<AssistantsIcon className="mr-2 my-auto" size={16} />{" "}
{persona.name}
</div>
</div>
</BasicSelectable>
</div>
{isEditable && (
<div className="pl-2 my-auto">
<Link href={`/assistants/edit/${persona.id}`}>
<FiEdit2
className="my-auto ml-auto hover:bg-hover p-0.5"
size={20}
/>
</Link>
</div>
)}
</div>
);
}
export function AssistantsTab({
personas,
onPersonaChange,
user,
}: {
personas: Persona[];
onPersonaChange: (persona: Persona | null) => void;
user: User | null;
}) {
const globalAssistants = personas.filter((persona) => persona.is_public);
const personalAssistants = personas.filter(
(persona) =>
(!user || persona.users.some((u) => u.id === user.id)) &&
!persona.is_public
);
return (
<div className="mt-4 pb-1 overflow-y-auto h-full flex flex-col gap-y-1">
<Text className="mx-3 text-xs mb-4">
Select an Assistant below to begin a new chat with them!
</Text>
<div className="mx-3">
{globalAssistants.length > 0 && (
<>
<div className="text-xs text-subtle flex pb-0.5 ml-1 mb-1.5 font-bold">
Global
</div>
{globalAssistants.map((persona) => {
return (
<AssistantDisplay
key={persona.id}
persona={persona}
onSelect={onPersonaChange}
user={user}
/>
);
})}
</>
)}
{personalAssistants.length > 0 && (
<>
<div className="text-xs text-subtle flex pb-0.5 ml-1 mb-1.5 mt-5 font-bold">
Personal
</div>
{personalAssistants.map((persona) => {
return (
<AssistantDisplay
key={persona.id}
persona={persona}
onSelect={onPersonaChange}
user={user}
/>
);
})}
</>
)}
</div>
</div>
);
}

View File

@@ -99,22 +99,28 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
h-screen
transition-transform`}
>
<div className="ml-4 mr-3 flex flex gap-x-1 items-center mt-2 my-auto text-text-700 text-xl">
<div className="mr-1 my-auto h-6 w-6">
<div className="max-w-full ml-3 mr-3 mt-2 flex flex gap-x-1 items-center my-auto text-text-700 text-xl">
<div className="mr-1 mb-auto h-6 w-6">
<Logo height={24} width={24} />
</div>
<div className="invisible">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
<div>
<HeaderTitle>
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Danswer</p>
)}
</div>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
</div>
{toggleSidebar && (
<Tooltip delayDuration={1000} content={`${commandSymbol}E show`}>
<button className="ml-auto" onClick={toggleSidebar}>
<button className="mb-auto ml-auto" onClick={toggleSidebar}>
{!toggled ? <RightToLineIcon /> : <LefToLineIcon />}
</button>
</Tooltip>

View File

@@ -3,6 +3,7 @@
import { HeaderTitle } from "@/components/header/Header";
import { Logo } from "@/components/Logo";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
import { useContext } from "react";
export default function FixedLogo() {
@@ -11,17 +12,24 @@ export default function FixedLogo() {
const enterpriseSettings = combinedSettings?.enterpriseSettings;
return (
<div className="fixed flex z-40 left-4 top-2">
{" "}
<a href="/chat" className="ml-7 text-text-700 text-xl">
<div>
<div className="absolute flex z-40 left-2.5 top-2">
<div className="max-w-[200px] flex gap-x-1 my-auto">
<div className="flex-none invisible mb-auto">
<Logo />
</div>
<div className="">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
<div>
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Danswer</p>
)}
</div>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
</div>
</a>
</div>
</div>
);
}

View File

@@ -10,12 +10,11 @@ const ToggleSwitch = () => {
const commandSymbol = KeyboardSymbol();
const pathname = usePathname();
const router = useRouter();
const [activeTab, setActiveTab] = useState(() => {
if (typeof window !== "undefined") {
return localStorage.getItem("activeTab") || "chat";
}
return "chat";
return pathname == "/search" ? "search" : "chat";
});
const [isInitialLoad, setIsInitialLoad] = useState(true);
useEffect(() => {

View File

@@ -168,7 +168,8 @@ export default async function Home() {
(ccPair) => ccPair.has_successful_run && ccPair.docs_indexed > 0
) &&
!shouldDisplayNoSourcesModal &&
!shouldShowWelcomeModal;
!shouldShowWelcomeModal &&
(!user || user.role == "admin");
const sidebarToggled = cookies().get(SIDEBAR_TOGGLED_COOKIE_NAME);
const agenticSearchToggle = cookies().get(AGENTIC_SEARCH_TYPE_COOKIE_NAME);
@@ -178,7 +179,7 @@ export default async function Home() {
: false;
const agenticSearchEnabled = agenticSearchToggle
? agenticSearchToggle.value.toLocaleLowerCase() == "true" || true
? agenticSearchToggle.value.toLocaleLowerCase() == "true" || false
: false;
return (

View File

@@ -0,0 +1,26 @@
import React from "react";
import { Button } from "@tremor/react";
import { FiChevronDown, FiChevronRight } from "react-icons/fi";
interface AdvancedOptionsToggleProps {
showAdvancedOptions: boolean;
setShowAdvancedOptions: (show: boolean) => void;
}
export function AdvancedOptionsToggle({
showAdvancedOptions,
setShowAdvancedOptions,
}: AdvancedOptionsToggleProps) {
return (
<Button
type="button"
variant="light"
size="xs"
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
className="mb-4 text-xs text-text-500 hover:text-text-400"
>
Advanced Options
</Button>
);
}
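A minimal usage sketch for AdvancedOptionsToggle — the surrounding form section and state name are assumptions for illustration; only the toggle's props come from the file above:

// Hypothetical form section gated by the toggle.
import { useState } from "react";

function ProviderFormExample() {
  const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);

  return (
    <div>
      <AdvancedOptionsToggle
        showAdvancedOptions={showAdvancedOptions}
        setShowAdvancedOptions={setShowAdvancedOptions}
      />
      {showAdvancedOptions && (
        <div>{/* advanced fields rendered only when expanded */}</div>
      )}
    </div>
  );
}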

View File

@@ -31,7 +31,10 @@ export function Logo({
}
return (
<div style={{ height, width }} className={`relative ${className}`}>
<div
style={{ height, width }}
className={`flex-none relative ${className}`}
>
{/* TODO: figure out how to use Next Image here */}
<img
src="/api/enterprise-settings/logo"

View File

@@ -91,7 +91,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
return (
<div className="h-screen overflow-y-hidden">
<div className="flex h-full">
<div className="w-64 z-20 bg-background-100 pt-4 pb-8 h-full border-r border-border miniscroll overflow-auto">
<div className="w-64 z-20 bg-background-100 pt-3 pb-8 h-full border-r border-border miniscroll overflow-auto">
<AdminSidebar
collections={[
{

View File

@@ -28,9 +28,9 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
return (
<aside className="pl-0">
<nav className="space-y-2 pl-4">
<div className="pb-12 flex">
<div className="fixed left-0 top-0 py-2 pl-4 bg-background-100 w-[200px]">
<nav className="space-y-2 pl-2">
<div className="mb-4 flex">
<div className="bg-background-100">
<Link
className="flex flex-col"
href={
@@ -39,8 +39,8 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
: "/search"
}
>
<div className="flex gap-x-1 my-auto">
<div className="my-auto">
<div className="max-w-[200px] flex gap-x-1 my-auto">
<div className="flex-none mb-auto">
<Logo />
</div>
<div className="my-auto">
@@ -50,7 +50,7 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle -mt-1.5">
<p className="text-xs text-subtle">
Powered by Danswer
</p>
)}
@@ -63,14 +63,16 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
</Link>
</div>
</div>
<Link href={"/chat"}>
<button className="text-sm block w-48 py-2.5 flex px-2 text-left bg-background-200 hover:bg-background-200/80 cursor-pointer rounded">
<BackIcon size={20} className="text-neutral" />
<p className="ml-1">Back to Danswer</p>
</button>
</Link>
<div className="px-3">
<Link href={"/chat"}>
<button className="text-sm block w-48 py-2.5 flex px-2 text-left bg-background-200 hover:bg-background-200/80 cursor-pointer rounded">
<BackIcon size={20} className="text-neutral" />
<p className="ml-1">Back to Danswer</p>
</button>
</Link>
</div>
{collections.map((collection, collectionInd) => (
<div key={collectionInd}>
<div className="px-3" key={collectionInd}>
<h2 className="text-xs text-strong font-bold pb-2">
<div>{collection.name}</div>
</h2>

View File

@@ -59,18 +59,23 @@ export default function FunctionalHeader({
<div className="pb-6 left-0 sticky top-0 z-10 w-full relative flex">
<div className="mt-2 mx-4 text-text-700 flex w-full">
<div className="absolute z-[100] my-auto flex items-center text-xl font-bold">
<FiSidebar size={20} />
<div className="ml-2 text-text-700 text-xl">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
<div className="pt-[2px] mb-auto">
<FiSidebar size={20} />
</div>
<div className="break-words inline-block w-fit ml-2 text-text-700 text-xl">
<div className="max-w-[200px]">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
</div>
</div>
{page == "chat" && (
<Tooltip delayDuration={1000} content={`${commandSymbol}U`}>
<Link
className="mb-auto pt-[2px]"
href={
`/${page}` +
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA &&
@@ -79,10 +84,9 @@ export default function FunctionalHeader({
: "")
}
>
<NewChatIcon
size={20}
className="ml-2 my-auto cursor-pointer text-text-700 hover:text-text-600 transition-colors duration-300"
/>
<div className="cursor-pointer ml-2 flex-none text-text-700 hover:text-text-600 transition-colors duration-300">
<NewChatIcon size={20} className="" />
</div>
</Link>
</Tooltip>
)}

View File

@@ -18,7 +18,7 @@ interface ChatContextProps {
chatSessions: ChatSession[];
availableSources: ValidSources[];
availableDocumentSets: DocumentSet[];
availablePersonas: Persona[];
availableAssistants: Persona[];
availableTags: Tag[];
llmProviders: LLMProviderDescriptor[];
folders: Folder[];

View File

@@ -11,7 +11,11 @@ import { Logo } from "../Logo";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
export function HeaderTitle({ children }: { children: JSX.Element | string }) {
return <h1 className="flex text-2xl text-strong font-bold">{children}</h1>;
return (
<h1 className="flex text-2xl text-strong leading-none font-bold">
{children}
</h1>
);
}
interface HeaderProps {
@@ -36,8 +40,8 @@ export function Header({ user, page }: HeaderProps) {
settings && settings.default_page === "chat" ? "/chat" : "/search"
}
>
<div className="flex my-auto">
<div className="mr-1 my-auto">
<div className="max-w-[200px] bg-black flex my-auto">
<div className="mr-1 mb-auto">
<Logo />
</div>
<div className="my-auto">
@@ -47,9 +51,7 @@ export function Header({ user, page }: HeaderProps) {
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle -mt-1.5">
Powered by Danswer
</p>
<p className="text-xs text-subtle">Powered by Danswer</p>
)}
</div>
) : (

View File

@@ -19,7 +19,7 @@ export function Citation({
return (
<CustomTooltip
citation
content={<p className="inline-block p-0 m-0 truncate">{link}</p>}
content={<div className="inline-block p-0 m-0 truncate">{link}</div>}
>
<a
onClick={() => (link ? window.open(link, "_blank") : undefined)}

View File

@@ -30,7 +30,6 @@ import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS";
interface FetchChatDataResult {
user: User | null;
chatSessions: ChatSession[];
ccPairs: CCPairBasicInfo[];
availableSources: ValidSources[];
documentSets: DocumentSet[];
@@ -39,7 +38,7 @@ interface FetchChatDataResult {
llmProviders: LLMProviderDescriptor[];
folders: Folder[];
openedFolders: Record<string, boolean>;
defaultPersonaId?: number;
defaultAssistantId?: number;
toggleSidebar: boolean;
finalDocumentSidebarInitialWidth?: number;
shouldShowWelcomeModal: boolean;
@@ -150,9 +149,9 @@ export async function fetchChatData(searchParams: {
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
}
const defaultPersonaIdRaw = searchParams["assistantId"];
const defaultPersonaId = defaultPersonaIdRaw
? parseInt(defaultPersonaIdRaw)
const defaultAssistantIdRaw = searchParams["assistantId"];
const defaultAssistantId = defaultAssistantIdRaw
? parseInt(defaultAssistantIdRaw)
: undefined;
const documentSidebarCookieInitialWidth = cookies().get(
@@ -178,7 +177,8 @@ export async function fetchChatData(searchParams: {
!shouldShowWelcomeModal &&
!ccPairs.some(
(ccPair) => ccPair.has_successful_run && ccPair.docs_indexed > 0
);
) &&
(!user || user.role == "admin");
// if no connectors are setup, only show personas that are pure
// passthrough and don't do any retrieval
@@ -209,7 +209,7 @@ export async function fetchChatData(searchParams: {
llmProviders,
folders,
openedFolders,
defaultPersonaId,
defaultAssistantId,
finalDocumentSidebarInitialWidth,
toggleSidebar,
shouldShowWelcomeModal,

View File

@@ -301,9 +301,3 @@ function stripTrailingSlash(str: string) {
}
return str;
}
export const getTitleFromDocument = (document: DanswerDocument) => {
return stripTrailingSlash(document.document_id).split("/")[
stripTrailingSlash(document.document_id).split("/").length - 1
];
};