Compare commits

...

27 Commits

Author SHA1 Message Date
pablodanswer
48a0d29a5c Fix empty / reverted embeddings (#1910) 2024-07-23 22:41:31 -07:00
hagen-danswer
6ff8e6c0ea Improve eval pipeline qol (#1908) 2024-07-23 17:16:34 -07:00
Yuhong Sun
2470c68506 Don't rephrase first chat query (#1907) 2024-07-23 16:20:11 -07:00
hagen-danswer
866bc803b1 Implemented LLM disabling for api call (#1905) 2024-07-23 16:12:51 -07:00
pablodanswer
9c6084bd0d Embeddings- Clean up modal + "Important" call out (#1903) 2024-07-22 21:29:22 -07:00
hagen-danswer
a0b46c60c6 Switched eval api target back to oneshotqa (#1902) 2024-07-22 20:55:18 -07:00
pablodanswer
4029233df0 hide incomplete sources for non-admins (#1901) 2024-07-22 13:40:11 -07:00
hagen-danswer
6c88c0156c Added file upload retry logic (#1889) 2024-07-22 13:13:22 -07:00
pablodanswer
33332d08f2 fix citation title (#1900)
* fix citation title

* remove title function
2024-07-22 17:37:04 +00:00
hagen-danswer
17005fb705 switched default pruning behavior and removed some logging (#1898) 2024-07-22 17:36:26 +00:00
hagen-danswer
48a7fe80b1 Committed LLM updates to db (#1899) 2024-07-22 10:30:24 -07:00
pablodanswer
1276732409 Misc bug fixes (#1895) 2024-07-22 10:22:43 -07:00
Weves
f91b92a898 Make is_public default true for LLMProvider 2024-07-21 22:22:37 -07:00
Weves
6222f533be Update force delete script to handle user groups 2024-07-21 22:22:37 -07:00
hagen-danswer
1b49d17239 Added ability to control LLM access based on group (#1870)
* Added ability to control LLM access based on group

* completed relationship deletion

* cleaned up function

* added comments

* fixed frontend strings

* mypy fixes

* added case handling for deletion of user groups

* hidden advanced options now

* removed unnecessary code
2024-07-22 04:31:44 +00:00
Yuhong Sun
2f5f19642e Double Check Max Tokens for Indexing (#1893) 2024-07-21 21:12:39 -07:00
Yuhong Sun
6db4634871 Token Truncation (#1892) 2024-07-21 16:26:32 -07:00
Yuhong Sun
5cfed45cef Handle Empty Titles (#1891) 2024-07-21 14:59:23 -07:00
Weves
581ffde35a Fix jira connector failures for server deployments 2024-07-21 14:44:25 -07:00
pablodanswer
6313e6d91d Remove visit api when unneeded (#1885)
* quick fix to test on ec2

* quick cleanup

* modify a name

* address full doc as well

* additional timing info + handling

* clean up

* squash

* Print only
2024-07-21 20:57:24 +00:00
Weves
c09c94bf32 Fix assistant swap 2024-07-21 13:57:36 -07:00
Yuhong Sun
0e8ba111c8 Model Touchups (#1887) 2024-07-21 12:31:00 -07:00
Yuhong Sun
2ba24b1734 Reenable Search Pipeline (#1886) 2024-07-21 10:33:29 -07:00
Yuhong Sun
44820b4909 k 2024-07-21 10:27:57 -07:00
hagen-danswer
eb3e7610fc Added retries and multithreading for cloud embedding (#1879)
* added retries and multithreading for cloud embedding

* refactored a bit

* cleaned up code

* got the errors to bubble up to the ui correctly

* added exception printing

* added requirements

* touchups

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-07-20 22:10:18 -07:00
pablodanswer
7fbbb174bb minor fixes (#1882)
- Assistants tab size
- Fixed logo -> absolute
2024-07-20 21:02:57 -07:00
pablodanswer
3854ca11af add newlines for message content 2024-07-20 18:57:29 -07:00
86 changed files with 1599 additions and 1306 deletions

View File

@@ -0,0 +1,41 @@
"""add_llm_group_permissions_control
Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558
"""
from alembic import op
import sqlalchemy as sa
revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"llm_provider__user_group",
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["llm_provider_id"],
["llm_provider.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
)
op.add_column(
"llm_provider",
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_table("llm_provider__user_group")
op.drop_column("llm_provider", "is_public")

View File

@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -80,7 +80,7 @@ def should_prune_cc_pair(
return True
return False
if PREVENT_SIMULTANEOUS_PRUNING:
if not ALLOW_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
@@ -89,11 +89,9 @@ def should_prune_cc_pair(
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
logger.info("Another Connector is already pruning. Skipping.")
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
return False
if not last_pruning_task.start_time:

View File

@@ -33,7 +33,7 @@ from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable

View File

@@ -51,7 +51,7 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
@@ -187,37 +187,46 @@ def _handle_internet_search_tool_response_summary(
)
def _check_should_force_search(
new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
# If files are already provided, don't run the search tool
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool:
internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
)
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Currently, the internet search tool does not support query override
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override and tool_name == SearchTool._NAME
else None
)
if new_msg_req.file_descriptors:
return None
# If user has uploaded files they're using, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
if (
new_msg_req.query_override
or (
should_force_search = any(
[
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
)
or new_msg_req.search_doc_ids
or DISABLE_LLM_CHOOSE_SEARCH
):
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override
else None
)
# if we are using selected docs, just put something here so the Tool doesn't need
# to build its own args via an LLM call
if new_msg_req.search_doc_ids:
args = {"query": new_msg_req.message}
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
DISABLE_LLM_CHOOSE_SEARCH,
]
)
return ForceUseTool(
tool_name=SearchTool._NAME,
args=args,
)
return None
if should_force_search:
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
ChatPacket = (
@@ -253,7 +262,6 @@ def stream_chat_message_objects(
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
try:
user_id = user.id if user is not None else None
@@ -361,6 +369,14 @@ def stream_chat_message_objects(
"when the last message is not a user message."
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worse search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
@@ -576,11 +592,7 @@ def stream_chat_message_objects(
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
tools=tools,
force_use_tool=(
_check_should_force_search(new_msg_req)
if search_tool and len(tools) == 1
else None
),
force_use_tool=_get_force_search_settings(new_msg_req, tools),
)
reference_db_search_docs = None
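For orientation, the ForceUseTool values produced by the new _get_force_search_settings helper fall into three cases. A minimal illustrative sketch, assuming only the field names visible in this diff (the literal tool name comes from SearchTool._NAME / InternetSearchTool._NAME, and the query string is a placeholder):
# Sketch only; mirrors the branches of _get_force_search_settings above.
# 1. The user attached files to the message -> never force a search tool
ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
# 2. query_override, run_search == ALWAYS, selected doc ids, or DISABLE_LLM_CHOOSE_SEARCH
#    -> force the search tool and hand it prebuilt args
ForceUseTool(force_use=True, tool_name=SearchTool._NAME, args={"query": "placeholder question"})
# 3. Otherwise leave the decision to the LLM, keeping any prebuilt args around
ForceUseTool(force_use=False, tool_name=SearchTool._NAME, args=None)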

View File

@@ -214,8 +214,8 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
PREVENT_SIMULTANEOUS_PRUNING = (
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.

View File

@@ -56,6 +56,16 @@ def extract_text_from_content(content: dict) -> str:
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
@@ -117,8 +127,10 @@ def fetch_jira_issues_batch(
continue
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments]
semantic_rep = (
f"{jira.fields.description}\n"
if jira.fields.description
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
)
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -147,14 +159,18 @@ def fetch_jira_issues_batch(
pass
metadata_dict = {}
if jira.fields.priority:
metadata_dict["priority"] = jira.fields.priority.name
if jira.fields.status:
metadata_dict["status"] = jira.fields.status.name
if jira.fields.resolution:
metadata_dict["resolution"] = jira.fields.resolution.name
if jira.fields.labels:
metadata_dict["label"] = jira.fields.labels
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
metadata_dict["priority"] = priority.name
status = best_effort_get_field_from_issue(jira, "status")
if status:
metadata_dict["status"] = status.name
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
metadata_dict["resolution"] = resolution.name
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
metadata_dict["label"] = labels
doc_batch.append(
Document(
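A small usage sketch of the new helper (the issue object and field name below are placeholders): it prefers the typed attribute on jira_issue.fields and falls back to the raw field dict, returning None when the field is missing, which is what keeps server deployments without certain fields from failing.
# Sketch only: jira_issue is any Jira Issue object; "priority" is an example field.
priority = best_effort_get_field_from_issue(jira_issue, "priority")
if priority:
    metadata_dict["priority"] = priority.name  # same pattern the connector uses above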

View File

@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
self.permissions: DiscoursePerms | None = None
self.active_categories: set | None = None
@rate_limit_builder(max_calls=100, period=60)
@rate_limit_builder(max_calls=50, period=60)
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
if not self.permissions:
raise ConnectorMissingCredentialError("Discourse")

View File

@@ -123,6 +123,8 @@ class DocumentBase(BaseModel):
for char in replace_chars:
title = title.replace(char, " ")
title = title.strip()
# Title could be quite long here as there is no truncation done
# just prior to embedding, it could be truncated
return title
def get_metadata_str_attributes(self) -> list[str] | None:

View File

@@ -50,9 +50,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST

View File

@@ -15,7 +15,7 @@ from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)

View File

@@ -1,15 +1,41 @@
from sqlalchemy import delete
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def update_group_llm_provider_relationships__no_commit(
llm_provider_id: int,
group_ids: list[int] | None,
db_session: Session,
) -> None:
# Delete existing relationships
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
).delete(synchronize_session="fetch")
# Add new relationships from given group_ids
if group_ids:
new_relationships = [
LLMProvider__UserGroup(
llm_provider_id=llm_provider_id,
user_group_id=group_id,
)
for group_id in group_ids
]
db_session.add_all(new_relationships)
def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
@@ -36,36 +62,35 @@ def upsert_llm_provider(
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider:
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = (
llm_provider.fast_default_model_name
)
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
provider=llm_provider.provider,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
default_model_name=llm_provider.default_model_name,
fast_default_model_name=llm_provider.fast_default_model_name,
model_names=llm_provider.model_names,
is_default_provider=None,
if not existing_llm_provider:
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
existing_llm_provider.model_names = llm_provider.model_names
existing_llm_provider.is_public = llm_provider.is_public
if not existing_llm_provider.id:
# If its not already in the db, we need to generate an ID by flushing
db_session.flush()
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
group_ids=llm_provider.groups,
db_session=db_session,
)
db_session.add(llm_provider_model)
db_session.commit()
return FullLLMProvider.from_model(llm_provider_model)
return FullLLMProvider.from_model(existing_llm_provider)
def fetch_existing_embedding_providers(
@@ -74,8 +99,29 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,
) -> list[LLMProviderModel]:
if not user:
return list(db_session.scalars(select(LLMProviderModel)).all())
stmt = select(LLMProviderModel).distinct()
user_groups_subquery = (
select(User__UserGroup.user_group_id)
.where(User__UserGroup.user_id == user.id)
.subquery()
)
access_conditions = or_(
LLMProviderModel.is_public,
LLMProviderModel.id.in_( # User is part of a group that has access
select(LLMProvider__UserGroup.llm_provider_id).where(
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
)
),
)
stmt = stmt.where(access_conditions)
return list(db_session.scalars(stmt).all())
def fetch_embedding_provider(
@@ -119,6 +165,13 @@ def remove_embedding_provider(
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
# Remove LLMProvider's dependent relationships
db_session.execute(
delete(LLMProvider__UserGroup).where(
LLMProvider__UserGroup.llm_provider_id == provider_id
)
)
# Remove LLMProvider
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)
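Putting this together with the API models later in this diff, upserting a group-restricted provider looks roughly like the sketch below. The field names come from this change; every value is a placeholder, and the exact required field set is assumed from the upsert code rather than confirmed.
# Hypothetical sketch: restrict a provider to user groups 1 and 2.
request = LLMProviderUpsertRequest(
    name="internal-openai",          # placeholder
    provider="openai",               # placeholder
    api_key="sk-...",                # placeholder
    api_base=None,
    api_version=None,
    custom_config=None,
    default_model_name="gpt-4o",     # placeholder
    fast_default_model_name=None,
    model_names=None,
    is_public=False,                 # new EE-only flag; False hides the provider from non-group users
    groups=[1, 2],                   # user_group ids written to llm_provider__user_group
)
provider = upsert_llm_provider(db_session, request)
With is_public=False, fetch_existing_llm_providers(db_session, user) only returns this provider to users whose groups appear in the new association table.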

View File

@@ -932,6 +932,13 @@ class LLMProvider(Base):
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="llm_provider__user_group",
viewonly=True,
)
class CloudEmbeddingProvider(Base):
@@ -1109,7 +1116,6 @@ class Persona(Base):
# where lower value IDs (e.g. created earlier) are displayed first
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
# These are only defaults, users can select from all if desired
prompts: Mapped[list[Prompt]] = relationship(
@@ -1137,6 +1143,7 @@ class Persona(Base):
viewonly=True,
)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="persona__user_group",
@@ -1360,6 +1367,17 @@ class Persona__UserGroup(Base):
)
class LLMProvider__UserGroup(Base):
__tablename__ = "llm_provider__user_group"
llm_provider_id: Mapped[int] = mapped_column(
ForeignKey("llm_provider.id"), primary_key=True
)
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
class DocumentSet__UserGroup(Base):
__tablename__ = "document_set__user_group"

View File

@@ -331,12 +331,18 @@ def _index_vespa_chunk(
document = chunk.source_document
# No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
vespa_chunk_id = str(get_uuid_from_chunk(chunk))
embeddings = chunk.embeddings
if chunk.embeddings.full_embedding is None:
embeddings.full_embedding = chunk.title_embedding
embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
if embeddings.mini_chunk_embeddings:
for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
if m_c_embed is None:
embeddings_name_vector_map[f"mini_chunk_{ind}"] = chunk.title_embedding
else:
embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
title = document.get_title_for_document_index()

View File

@@ -103,6 +103,7 @@ def port_api_key_to_postgres() -> None:
default_model_name=default_model_name,
fast_default_model_name=default_fast_model_name,
model_names=None,
is_public=True,
)
llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
update_default_provider(db_session, llm_provider.id)

View File

@@ -15,7 +15,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
)
from danswer.connectors.models import Document
from danswer.indexing.models import DocAwareChunk
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import shared_precompare_cleanup

View File

@@ -4,7 +4,6 @@ from abc import abstractmethod
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -14,8 +13,7 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.utils.batching import batch_list
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
@@ -66,21 +64,21 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
)
def embed_chunks(
self,
chunks: list[DocAwareChunk],
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[IndexChunk]:
# Cache the Title embeddings to only have to do it once
title_embed_dict: dict[str, list[float]] = {}
title_embed_dict: dict[str, list[float] | None] = {}
embedded_chunks: list[IndexChunk] = []
# Create Mini Chunks for more precise matching of details
# Off by default with unedited settings
chunk_texts = []
chunk_texts: list[str] = []
chunk_mini_chunks_count = {}
for chunk_ind, chunk in enumerate(chunks):
chunk_texts.append(chunk.content)
@@ -92,22 +90,9 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
chunk_texts.extend(mini_chunk_texts)
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
# Batching for embedding
text_batches = batch_list(chunk_texts, batch_size)
embeddings: list[list[float]] = []
len_text_batches = len(text_batches)
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
# Normalize embeddings is only configured via model_configs.py, be sure to use right
# value for the set loss
embeddings.extend(
self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
)
# Replace line above with the line below for easy debugging of indexing flow
# skipping the actual model
# embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
embeddings = self.embedding_model.encode(
chunk_texts, text_type=EmbedTextType.PASSAGE
)
chunk_titles = {
chunk.source_document.get_title_for_document_index() for chunk in chunks
@@ -116,16 +101,15 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
# Drop any None or empty strings
chunk_titles_list = [title for title in chunk_titles if title]
# Embed Titles in batches
title_batches = batch_list(chunk_titles_list, batch_size)
len_title_batches = len(title_batches)
for ind_batch, title_batch in enumerate(title_batches, start=1):
logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
if chunk_titles_list:
title_embeddings = self.embedding_model.encode(
title_batch, text_type=EmbedTextType.PASSAGE
chunk_titles_list, text_type=EmbedTextType.PASSAGE
)
title_embed_dict.update(
{title: vector for title, vector in zip(title_batch, title_embeddings)}
{
title: vector
for title, vector in zip(chunk_titles_list, title_embeddings)
}
)
# Mapping embeddings to chunks
@@ -184,4 +168,6 @@ def get_embedding_model_from_db_embedding_model(
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
provider_type=db_embedding_model.provider_type,
api_key=db_embedding_model.api_key,
)

View File

@@ -13,7 +13,7 @@ if TYPE_CHECKING:
logger = setup_logger()
Embedding = list[float]
Embedding = list[float] | None
class ChunkEmbedding(BaseModel):

View File

@@ -34,8 +34,8 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import message_generator_to_string_generator
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
@@ -99,6 +99,7 @@ class Answer:
answer_style_config: AnswerStyleConfig,
llm: LLM,
prompt_config: PromptConfig,
force_use_tool: ForceUseTool,
# must be the same length as `docs`. If None, all docs are considered "relevant"
message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None,
@@ -107,10 +108,8 @@ class Answer:
latest_query_files: list[InMemoryChatFile] | None = None,
files: list[InMemoryChatFile] | None = None,
tools: list[Tool] | None = None,
# if specified, tells the LLM to always this tool
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
# but we only support them anyways
force_use_tool: ForceUseTool | None = None,
# if set to True, then never use the LLM's provided tool-calling functionality
skip_explicit_tool_calling: bool = False,
# Returns the full document sections text from the search tool
@@ -129,6 +128,7 @@ class Answer:
self.tools = tools or []
self.force_use_tool = force_use_tool
self.skip_explicit_tool_calling = skip_explicit_tool_calling
self.message_history = message_history or []
@@ -187,7 +187,7 @@ class Answer:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool and self.force_use_tool.args is not None:
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
# / need to generate the args
tool_call_chunk = AIMessageChunk(
@@ -221,7 +221,7 @@ class Answer:
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool else None,
tool_choice="required" if self.force_use_tool.force_use else None,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
@@ -245,7 +245,8 @@ class Answer:
][0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool and self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
@@ -303,7 +304,7 @@ class Answer:
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,

View File

@@ -12,8 +12,8 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow

View File

@@ -14,8 +14,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection

View File

@@ -70,6 +70,7 @@ def load_llm_providers(db_session: Session) -> None:
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
),
model_names=model_names,
is_public=True,
)
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
update_default_provider(db_session, llm_provider.id)

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable
from collections.abc import Iterator
from copy import copy
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -16,10 +15,8 @@ from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from tiktoken.core import Encoding
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
@@ -28,7 +25,6 @@ from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL
@@ -37,53 +33,6 @@ if TYPE_CHECKING:
logger = setup_logger()
_LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
def get_default_llm_tokenizer() -> Encoding:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
if _LLM_TOKENIZER is None:
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
return _LLM_TOKENIZER
def get_default_llm_token_encode() -> Callable[[str], Any]:
global _LLM_TOKENIZER_ENCODE
if _LLM_TOKENIZER_ENCODE is None:
tokenizer = get_default_llm_tokenizer()
if isinstance(tokenizer, Encoding):
return tokenizer.encode # type: ignore
# Currently only supports OpenAI encoder
raise ValueError("Invalid Encoder selected")
return _LLM_TOKENIZER_ENCODE
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: Encoding
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
def tokenizer_trim_chunks(
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
) -> list[InferenceChunk]:
tokenizer = get_default_llm_tokenizer()
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks
def translate_danswer_msg_to_langchain(
msg: Union[ChatMessage, "PreviousMessage"],

View File

@@ -50,8 +50,8 @@ from danswer.db.standard_answer import create_initial_default_standard_answer_ca
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.llm.llm_initialization import load_llm_providers
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router

View File

@@ -1,14 +1,13 @@
import gc
import os
import time
from typing import Optional
from typing import TYPE_CHECKING
import requests
from transformers import logging as transformer_logging # type:ignore
from httpx import HTTPError
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.natural_language_processing.utils import get_default_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
@@ -20,50 +19,13 @@ from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
logger = setup_logger()
if TYPE_CHECKING:
from transformers import AutoTokenizer # type: ignore
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
# NOTE: doing a local import here to reduce memory usage caused by
# processes importing this file despite not using any of this
from transformers import AutoTokenizer # type: ignore
global _TOKENIZER
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
if _TOKENIZER[0] is not None:
del _TOKENIZER
gc.collect()
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
return _TOKENIZER[0]
def build_model_server_url(
model_server_host: str,
model_server_port: int,
@@ -91,6 +53,7 @@ class EmbeddingModel:
provider_type: str | None,
# The following are globals are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
retrim_content: bool = False,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -99,32 +62,90 @@ class EmbeddingModel:
self.passage_prefix = passage_prefix
self.normalize = normalize
self.model_name = model_name
self.retrim_content = retrim_content
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
if text_type == EmbedTextType.QUERY and self.query_prefix:
prefixed_texts = [self.query_prefix + text for text in texts]
elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
prefixed_texts = [self.passage_prefix + text for text in texts]
else:
prefixed_texts = texts
def encode(
self,
texts: list[str],
text_type: EmbedTextType,
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
) -> list[list[float] | None]:
if not texts:
logger.warning("No texts to be embedded")
return []
embed_request = EmbedRequest(
model_name=self.model_name,
texts=prefixed_texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
)
if self.retrim_content:
# This is applied during indexing as a catchall for overly long titles (or other uncapped fields)
# Note that this uses just the default tokenizer which may also lead to very minor miscountings
# However this slight miscounting is very unlikely to have any material impact.
texts = [
tokenizer_trim_content(
content=text,
desired_length=self.max_seq_length,
tokenizer=get_default_tokenizer(),
)
for text in texts
]
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
response.raise_for_status()
if self.provider_type:
embed_request = EmbedRequest(
model_name=self.model_name,
texts=texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
EmbedResponse(**response.json()).embeddings
return EmbedResponse(**response.json()).embeddings
return EmbedResponse(**response.json()).embeddings
# Batching for local embedding
text_batches = batch_list(texts, batch_size)
embeddings: list[list[float] | None] = []
for idx, text_batch in enumerate(text_batches, start=1):
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
# Normalize embeddings is only configured via model_configs.py, be sure to use right
# value for the set loss
embeddings.extend(EmbedResponse(**response.json()).embeddings)
return embeddings
class CrossEncoderEnsembleModel:
@@ -136,7 +157,7 @@ class CrossEncoderEnsembleModel:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
def predict(self, query: str, passages: list[str]) -> list[list[float]]:
def predict(self, query: str, passages: list[str]) -> list[list[float] | None]:
rerank_request = RerankRequest(query=query, documents=passages)
response = requests.post(
@@ -199,7 +220,7 @@ def warm_up_encoders(
# First time downloading the models it may take even longer, but just in case,
# retry the whole server
wait_time = 5
for attempt in range(20):
for _ in range(20):
try:
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
return
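As a rough usage sketch of the reworked encode path (the model name, prefixes, and input texts are placeholders, and the constructor's full signature is assumed from the fragments above): a cloud provider_type sends a single request per call, while a local model server is batched with batch_list and, when retrim_content is set, inputs are re-trimmed to max_seq_length before being sent.
# Sketch only; keyword names appear in this diff, values are placeholders.
embedding_model = EmbeddingModel(
    model_name="nomic-ai/nomic-embed-text-v1",   # placeholder
    query_prefix="query: ",                      # placeholder
    passage_prefix="passage: ",                  # placeholder
    normalize=True,
    api_key=None,
    provider_type=None,      # None -> local model server, so encode() batches requests
    server_host=INDEXING_MODEL_SERVER_HOST,
    server_port=INDEXING_MODEL_SERVER_PORT,
    retrim_content=True,     # catch overly long titles or other uncapped fields
)
vectors = embedding_model.encode(chunk_texts, text_type=EmbedTextType.PASSAGE)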

View File

@@ -0,0 +1,100 @@
import gc
import os
from collections.abc import Callable
from copy import copy
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
import tiktoken
from tiktoken.core import Encoding
from transformers import logging as transformer_logging # type:ignore
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
if TYPE_CHECKING:
from transformers import AutoTokenizer # type: ignore
logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
_LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
# NOTE: doing a local import here to reduce memory usage caused by
# processes importing this file despite not using any of this
from transformers import AutoTokenizer # type: ignore
global _TOKENIZER
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
if _TOKENIZER[0] is not None:
del _TOKENIZER
gc.collect()
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
return _TOKENIZER[0]
def get_default_llm_tokenizer() -> Encoding:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
if _LLM_TOKENIZER is None:
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
return _LLM_TOKENIZER
def get_default_llm_token_encode() -> Callable[[str], Any]:
global _LLM_TOKENIZER_ENCODE
if _LLM_TOKENIZER_ENCODE is None:
tokenizer = get_default_llm_tokenizer()
if isinstance(tokenizer, Encoding):
return tokenizer.encode # type: ignore
# Currently only supports OpenAI encoder
raise ValueError("Invalid Encoder selected")
return _LLM_TOKENIZER_ENCODE
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: Encoding
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
def tokenizer_trim_chunks(
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
) -> list[InferenceChunk]:
tokenizer = get_default_llm_tokenizer()
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks
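A brief usage sketch of the relocated helpers (the input string and length are placeholders): trimming operates on tiktoken token ids, so the result is capped at desired_length tokens, not characters.
# Sketch: trim an overly long string to 512 cl100k_base tokens.
tokenizer = get_default_llm_tokenizer()          # tiktoken "cl100k_base" Encoding
trimmed = tokenizer_trim_content(
    content="some very long document text ...",  # placeholder
    desired_length=512,
    tokenizer=tokenizer,
)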

View File

@@ -34,7 +34,7 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
@@ -206,6 +206,7 @@ def stream_answer_objects(
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
force_use=True,
tool_name=search_tool.name,
args={"query": rephrased_query},
),
@@ -256,6 +257,9 @@ def stream_answer_objects(
)
yield initial_response
elif packet.id == SEARCH_DOC_CONTENT_ID:
yield packet.response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
@@ -267,9 +271,12 @@ def stream_answer_objects(
)
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
elif packet.id == SEARCH_DOC_CONTENT_ID:
yield packet.response
if query_req.skip_gen_ai_answer_generation:
# Exit early if only source docs + contexts are requested
# Putting exit here assumes that a packet with the ID
# SECTION_RELEVANCE_LIST_ID is the last one yielded before
# calling the LLM
return
elif packet.id == SEARCH_EVALUATION_ID:
evaluation_response = LLMRelevanceSummaryResponse(

View File

@@ -2,7 +2,7 @@ from collections.abc import Callable
from collections.abc import Generator
from danswer.configs.constants import MessageType
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import ThreadMessage
from danswer.utils.logger import setup_logger

View File

@@ -34,6 +34,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time
logger = setup_logger()
@@ -154,6 +155,7 @@ class SearchPipeline:
return cast(list[InferenceChunk], self._retrieved_chunks)
@log_function_time(print_only=True)
def _get_sections(self) -> list[InferenceSection]:
"""Returns an expanded section from each of the chunks.
If whole docs (instead of above/below context) is specified then it will give back all of the whole docs
@@ -173,9 +175,11 @@ class SearchPipeline:
expanded_inference_sections = []
# Full doc setting takes priority
if self.search_query.full_doc:
seen_document_ids = set()
unique_chunks = []
# This preserves the ordering since the chunks are retrieved in score order
for chunk in retrieved_chunks:
if chunk.document_id not in seen_document_ids:
@@ -195,7 +199,6 @@ class SearchPipeline:
),
)
)
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
@@ -240,32 +243,35 @@ class SearchPipeline:
merged_ranges = [
merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
]
flat_ranges = [r for ranges in merged_ranges for r in ranges]
flat_ranges: list[ChunkRange] = [r for ranges in merged_ranges for r in ranges]
flattened_inference_chunks: list[InferenceChunk] = []
parallel_functions_with_args = []
for chunk_range in flat_ranges:
functions_with_args.append(
(
# If Large Chunks are introduced, additional filters need to be added here
self.document_index.id_based_retrieval,
(
# Only need the document_id here, just use any chunk in the range is fine
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
# There is no chunk level permissioning, this expansion around chunks
# can be assumed to be safe
IndexFilters(access_control_list=None),
),
)
)
# Don't need to fetch chunks within range for merging if chunk_above / below are 0.
if above == below == 0:
flattened_inference_chunks.extend(chunk_range.chunks)
# list of list of inference chunks where the inner list needs to be combined for content
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
flattened_inference_chunks = [
chunk for sublist in list_inference_chunks for chunk in sublist
]
else:
parallel_functions_with_args.append(
(
self.document_index.id_based_retrieval,
(
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
IndexFilters(access_control_list=None),
),
)
)
if parallel_functions_with_args:
list_inference_chunks = run_functions_tuples_in_parallel(
parallel_functions_with_args, allow_failures=False
)
for inference_chunks in list_inference_chunks:
flattened_inference_chunks.extend(inference_chunks)
doc_chunk_ind_to_chunk = {
(chunk.document_id, chunk.chunk_id): chunk

View File

@@ -12,6 +12,9 @@ from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import (
CrossEncoderEnsembleModel,
)
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
@@ -20,7 +23,6 @@ from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall

View File

@@ -1,10 +1,10 @@
from typing import TYPE_CHECKING
from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
from danswer.natural_language_processing.search_nlp_models import IntentModel
from danswer.search.enums import QueryFlow
from danswer.search.models import SearchType
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.search.search_nlp_models import IntentModel
from danswer.server.query_and_chat.models import HelperResponse
from danswer.utils.logger import setup_logger

View File

@@ -1,5 +1,6 @@
import string
from collections.abc import Callable
from typing import cast
import nltk # type:ignore
from nltk.corpus import stopwords # type:ignore
@@ -11,6 +12,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
@@ -20,7 +22,6 @@ from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
@@ -143,7 +144,9 @@ def doc_index_retrieval(
if query.search_type == SearchType.SEMANTIC:
top_chunks = document_index.semantic_retrieval(
query=query.query,
query_embedding=query_embedding,
query_embedding=cast(
list[float], query_embedding
), # query embeddings should always have vector representations
filters=query.filters,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,
@@ -152,7 +155,9 @@ def doc_index_retrieval(
elif query.search_type == SearchType.HYBRID:
top_chunks = document_index.hybrid_retrieval(
query=query.query,
query_embedding=query_embedding,
query_embedding=cast(
list[float], query_embedding
), # query embeddings should always have vector representations
filters=query.filters,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,

View File

@@ -94,7 +94,7 @@ def history_based_query_rephrase(
llm: LLM,
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
skip_first_rephrase: bool = False,
skip_first_rephrase: bool = True,
prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> str:
# Globally disabled, just use the exact user query

View File

@@ -96,6 +96,8 @@ def upsert_ingestion_doc(
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
)
indexing_pipeline = build_indexing_pipeline(
@@ -132,6 +134,8 @@ def upsert_ingestion_doc(
normalize=sec_db_embedding_model.normalize,
query_prefix=sec_db_embedding_model.query_prefix,
passage_prefix=sec_db_embedding_model.passage_prefix,
api_key=sec_db_embedding_model.api_key,
provider_type=sec_db_embedding_model.provider_type,
)
sec_ind_pipeline = build_indexing_pipeline(

View File

@@ -9,7 +9,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.server.documents.models import ChunkInfo

View File

@@ -10,7 +10,7 @@ from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import remove_embedding_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.models import User
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.embedding.models import TestEmbeddingRequest
@@ -42,7 +42,7 @@ def test_embedding_configuration(
passage_prefix=None,
model_name=None,
)
test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)
except ValueError as e:
error_msg = f"Not a valid embedding model. Exception thrown: {e}"

View File

@@ -147,10 +147,10 @@ def set_provider_as_default(
@basic_router.get("/provider")
def list_llm_provider_basics(
_: User | None = Depends(current_user),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
return [
LLMProviderDescriptor.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
]

View File

@@ -60,6 +60,8 @@ class LLMProvider(BaseModel):
custom_config: dict[str, str] | None
default_model_name: str
fast_default_model_name: str | None
is_public: bool = True
groups: list[int] | None = None
class LLMProviderUpsertRequest(LLMProvider):
@@ -91,4 +93,6 @@ class FullLLMProvider(LLMProvider):
or fetch_models_for_provider(llm_provider_model.provider)
or [llm_provider_model.default_model_name]
),
is_public=llm_provider_model.is_public,
groups=[group.id for group in llm_provider_model.groups],
)
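The new `is_public` and `groups` fields are what drive group-based access to LLM providers. The SQL-side filtering in `fetch_existing_llm_providers` is not part of this hunk, so the following is only a minimal sketch of the rule the admin UI describes (a public provider is visible to everyone, otherwise only to users in one of its groups); `LLMProviderView` and `visible_providers` are hypothetical names used for illustration.

```
# Minimal sketch, not the repo's actual query: a provider is accessible if it is
# public or if the user shares at least one user group with it.
from dataclasses import dataclass, field


@dataclass
class LLMProviderView:  # hypothetical stand-in for the real ORM/Pydantic model
    name: str
    is_public: bool = True
    groups: list[int] = field(default_factory=list)  # user group ids


def visible_providers(
    providers: list[LLMProviderView], user_group_ids: set[int]
) -> list[LLMProviderView]:
    return [
        p for p in providers if p.is_public or user_group_ids.intersection(p.groups)
    ]


providers = [
    LLMProviderView("openai", is_public=True),
    LLMProviderView("internal-azure", is_public=False, groups=[1, 2]),
]
# A user who is only in group 3 sees just the public provider
assert [p.name for p in visible_providers(providers, {3})] == ["openai"]
```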

View File

@@ -45,7 +45,7 @@ from danswer.llm.answering.prompts.citations_prompt import (
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.headers import get_litellm_additional_request_headers
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
)

View File

@@ -90,7 +90,7 @@ class CreateChatMessageRequest(ChunkContext):
parent_message_id: int | None
# New message contents
message: str
# file's that we should attach to this message
# Files that we should attach to this message
file_descriptors: list[FileDescriptor]
# If no prompt provided, uses the largest prompt of the chat session
# but really this should be explicitly specified, only in the simplified APIs is this inferred

View File

@@ -1,13 +1,15 @@
from typing import Any
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
from danswer.tools.tool import Tool
class ForceUseTool(BaseModel):
# Could be not a forced usage of the tool but still have args, in which case
# if the tool is called, then those args are applied instead of what the LLM
# wanted to call it with
force_use: bool
tool_name: str
args: dict[str, Any] | None = None
@@ -16,25 +18,10 @@ class ForceUseTool(BaseModel):
return {"type": "function", "function": {"name": self.tool_name}}
def modify_message_chain_for_force_use_tool(
messages: list[BaseMessage], force_use_tool: ForceUseTool | None = None
) -> list[BaseMessage]:
"""NOTE: modifies `messages` in place."""
if not force_use_tool:
return messages
for message in messages:
if isinstance(message, AIMessage) and message.tool_calls:
for tool_call in message.tool_calls:
tool_call["args"] = force_use_tool.args or {}
return messages
def filter_tools_for_force_tool_use(
tools: list[Tool], force_use_tool: ForceUseTool | None = None
tools: list[Tool], force_use_tool: ForceUseTool
) -> list[Tool]:
if not force_use_tool:
if not force_use_tool.force_use:
return tools
return [tool for tool in tools if tool.name == force_use_tool.tool_name]
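Since `modify_message_chain_for_force_use_tool` is gone, forcing a tool now only narrows the tool list: `force_use_tool` is always passed in, and its `force_use` flag decides whether anything is filtered. A self-contained sketch of that behavior follows; `SimpleTool` is a hypothetical stand-in, since the real `Tool` class carries much more than a name.

```
# Sketch of the new filter_tools_for_force_tool_use behavior with stand-in types.
from dataclasses import dataclass
from typing import Any


@dataclass
class SimpleTool:  # hypothetical stand-in for danswer.tools.tool.Tool
    name: str


@dataclass
class ForceUseTool:
    force_use: bool
    tool_name: str
    args: dict[str, Any] | None = None


def filter_tools_for_force_tool_use(
    tools: list[SimpleTool], force_use_tool: ForceUseTool
) -> list[SimpleTool]:
    if not force_use_tool.force_use:
        return tools  # nothing forced: the LLM may pick any tool
    return [tool for tool in tools if tool.name == force_use_tool.tool_name]


tools = [SimpleTool("run_search"), SimpleTool("image_generation")]
# Not forced: both tools remain available
assert len(filter_tools_for_force_tool_use(tools, ForceUseTool(False, "run_search"))) == 2
# Forced: only the named tool is offered to the LLM
assert [t.name for t in filter_tools_for_force_tool_use(tools, ForceUseTool(True, "run_search"))] == ["run_search"]
```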

View File

@@ -6,7 +6,7 @@ from langchain_core.messages.tool import ToolCall
from langchain_core.messages.tool import ToolMessage
from pydantic import BaseModel
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
def build_tool_message(

View File

@@ -2,7 +2,7 @@ import json
from tiktoken import Encoding
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.tools.tool import Tool

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
@@ -194,6 +195,15 @@ def _cleanup_user__user_group_relationships__no_commit(
db_session.delete(user__user_group_relationship)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session: Session, user_group_id: int
) -> None:
@@ -316,6 +326,9 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)

View File

@@ -10,6 +10,7 @@ from cohere import Client as CohereClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from vertexai.language_models import TextEmbeddingInput # type: ignore
@@ -40,110 +41,133 @@ router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
# If we are not only indexing, don't want to retry for very long
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
def _initialize_client(
api_key: str, provider: EmbeddingProvider, model: str | None = None
) -> Any:
if provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=api_key)
elif provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=api_key)
elif provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=api_key)
elif provider == EmbeddingProvider.GOOGLE:
credentials = service_account.Credentials.from_service_account_info(
json.loads(api_key)
)
project_id = json.loads(api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
else:
raise ValueError(f"Unsupported provider: {provider}")
class CloudEmbedding:
def __init__(self, api_key: str, provider: str, model: str | None = None):
self.api_key = api_key
def __init__(
self,
api_key: str,
provider: str,
# Only for Google, as it's needed on client setup
self.model = model
model: str | None = None,
) -> None:
try:
self.provider = EmbeddingProvider(provider.lower())
except ValueError:
raise ValueError(f"Unsupported provider: {provider}")
self.client = self._initialize_client()
self.client = _initialize_client(api_key, self.provider, model)
def _initialize_client(self) -> Any:
if self.provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=self.api_key)
elif self.provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=self.api_key)
elif self.provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=self.api_key)
elif self.provider == EmbeddingProvider.GOOGLE:
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.api_key)
)
project_id = json.loads(self.api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(
self.model or DEFAULT_VERTEX_MODEL
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def encode(
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
) -> list[list[float]]:
return [
self.embed(text=text, text_type=text_type, model=model_name)
for text in texts
]
def embed(
self, *, text: str, text_type: EmbedTextType, model: str | None = None
) -> list[float]:
logger.debug(f"Embedding text with provider: {self.provider}")
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(text, model)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(text, model, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(text, model, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(text, model, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def _embed_openai(self, text: str, model: str | None) -> list[float]:
def _embed_openai(
self, texts: list[str], model: str | None
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_OPENAI_MODEL
response = self.client.embeddings.create(input=text, model=model)
return response.data[0].embedding
# OpenAI does not seem to provide a truncation option; however,
# the context lengths used by Danswer are currently smaller than the max token length
# for OpenAI embeddings, so it's not a big deal
response = self.client.embeddings.create(input=texts, model=model)
return [embedding.embedding for embedding in response.data]
def _embed_cohere(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_COHERE_MODEL
# Does not use the same tokenizer as the Danswer API server, but it's approximately the same;
# empirically it's only off by a few tokens, so it's not a big deal
response = self.client.embed(
texts=[text],
texts=texts,
model=model,
input_type=embedding_type,
truncate="END",
)
return response.embeddings[0]
return response.embeddings
def _embed_voyage(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_VOYAGE_MODEL
response = self.client.embed(text, model=model, input_type=embedding_type)
return response.embeddings[0]
# Similar to Cohere, the API server will do approximate size chunking,
# so it's acceptable to miss by a few tokens
response = self.client.embed(
texts,
model=model,
input_type=embedding_type,
truncation=True,  # This is also the default
)
return response.embeddings
def _embed_vertex(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_VERTEX_MODEL
embedding = self.client.get_embeddings(
embeddings = self.client.get_embeddings(
[
TextEmbeddingInput(
text,
embedding_type,
)
]
for text in texts
],
auto_truncate=True,  # This is also the default
)
return embedding[0].values
return [embedding.values for embedding in embeddings]
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
def embed(
self,
*,
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
) -> list[list[float] | None]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error embedding text with {self.provider}: {str(e)}",
)
@staticmethod
def create(
@@ -212,34 +236,83 @@ def embed_text(
normalize_embeddings: bool,
api_key: str | None,
provider_type: str | None,
) -> list[list[float]]:
if provider_type is not None:
prefix: str | None,
) -> list[list[float] | None]:
non_empty_texts = []
empty_indices = []
for idx, text in enumerate(texts):
if text.strip():
non_empty_texts.append(text)
else:
empty_indices.append(idx)
# Third party API based embedding model
if not non_empty_texts:
embeddings = []
elif provider_type is not None:
logger.debug(f"Embedding text with provider: {provider_type}")
if api_key is None:
raise RuntimeError("API key not provided for cloud model")
if prefix:
# This may change in the future if some providers require the user
# to manually append a prefix but this is not the case currently
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
)
cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.encode(texts, model_name, text_type)
embeddings = cloud_model.embed(
texts=non_empty_texts,
model_name=model_name,
text_type=text_type,
)
elif model_name is not None:
hosted_model = get_embedding_model(
prefixed_texts = (
[f"{prefix}{text}" for text in non_empty_texts]
if prefix
else non_empty_texts
)
local_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = hosted_model.encode(
texts, normalize_embeddings=normalize_embeddings
embeddings = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
)
else:
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
if embeddings is None:
raise RuntimeError("Embeddings were not created")
raise RuntimeError("Failed to create Embeddings")
if not isinstance(embeddings, list):
embeddings = embeddings.tolist()
embeddings_with_nulls: list[list[float] | None] = []
current_embedding_index = 0
for idx in range(len(texts)):
if idx in empty_indices:
embeddings_with_nulls.append(None)
else:
embedding = embeddings[current_embedding_index]
if isinstance(embedding, list) or embedding is None:
embeddings_with_nulls.append(embedding)
else:
embeddings_with_nulls.append(embedding.tolist())
current_embedding_index += 1
embeddings = embeddings_with_nulls
return embeddings
@simple_log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float] | None]:
cross_encoders = get_local_reranking_model_ensemble()
sim_scores = [
encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
@@ -252,7 +325,17 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
async def process_embed_request(
embed_request: EmbedRequest,
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
try:
if embed_request.text_type == EmbedTextType.QUERY:
prefix = embed_request.manual_query_prefix
elif embed_request.text_type == EmbedTextType.PASSAGE:
prefix = embed_request.manual_passage_prefix
else:
prefix = None
embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
@@ -261,13 +344,13 @@ async def process_embed_request(
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)
except Exception as e:
logger.exception(f"Error during embedding process:\n{str(e)}")
raise HTTPException(
status_code=500, detail="Failed to run Bi-Encoder embedding"
)
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
raise HTTPException(status_code=500, detail=exception_detail)
@router.post("/cross-encoder-scores")
@@ -276,6 +359,11 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
if not embed_request.documents or not embed_request.query:
raise HTTPException(
status_code=400, detail="No documents or query to be reranked"
)
try:
sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents
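Two things changed in this file: `CloudEmbedding.embed` is now wrapped in `@retry` (with longer retries when `INDEXING_ONLY` is set), and `embed_text` strips empty strings out before calling the model and re-inserts `None` placeholders afterwards so the result stays index-aligned with the input. A stripped-down sketch of just the alignment pattern, with `fake_embed` standing in for the real model call:

```
# Embed only the non-empty strings, then put None back at the positions of the
# empty inputs so callers can rely on output[i] matching texts[i].
def fake_embed(texts: list[str]) -> list[list[float]]:
    # stand-in for the real embedding call; one vector per non-empty text
    return [[float(len(t))] for t in texts]


def embed_with_empty_handling(texts: list[str]) -> list[list[float] | None]:
    non_empty_texts: list[str] = []
    empty_indices: set[int] = set()
    for idx, text in enumerate(texts):
        if text.strip():
            non_empty_texts.append(text)
        else:
            empty_indices.add(idx)

    embeddings = fake_embed(non_empty_texts) if non_empty_texts else []

    results: list[list[float] | None] = []
    embedding_iter = iter(embeddings)
    for idx in range(len(texts)):
        results.append(None if idx in empty_indices else next(embedding_iter))
    return results


assert embed_with_empty_handling(["hello", "", "world"]) == [[5.0], None, [5.0]]
```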

View File

@@ -1,6 +1,7 @@
fastapi==0.109.2
h5py==3.9.0
pydantic==1.10.13
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
tensorflow==2.15.0
@@ -9,5 +10,5 @@ transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
openai==1.14.3
cohere==5.5.8
google-cloud-aiplatform==1.58.0
cohere==5.6.1
google-cloud-aiplatform==1.58.0

View File

@@ -14,7 +14,10 @@ sys.path.append(parent_dir)
# flake8: noqa: E402
# Now import Danswer modules
from danswer.db.models import DocumentSet__ConnectorCredentialPair
from danswer.db.models import (
DocumentSet__ConnectorCredentialPair,
UserGroup__ConnectorCredentialPair,
)
from danswer.db.connector import fetch_connector_by_id
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.index_attempt import (
@@ -44,7 +47,7 @@ logger = setup_logger()
_DELETION_BATCH_SIZE = 1000
def unsafe_deletion(
def _unsafe_deletion(
db_session: Session,
document_index: DocumentIndex,
cc_pair: ConnectorCredentialPair,
@@ -82,11 +85,22 @@ def unsafe_deletion(
credential_id=credential_id,
)
# Delete document sets + connector / credential Pairs
# Delete document sets
stmt = delete(DocumentSet__ConnectorCredentialPair).where(
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id == pair_id
)
db_session.execute(stmt)
# delete user group associations
stmt = delete(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.cc_pair_id == pair_id
)
db_session.execute(stmt)
# need to flush to avoid foreign key violations
db_session.flush()
# delete the actual connector credential pair
stmt = delete(ConnectorCredentialPair).where(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
@@ -168,7 +182,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
files_deleted_count = unsafe_deletion(
files_deleted_count = _unsafe_deletion(
db_session=db_session,
document_index=document_index,
cc_pair=cc_pair,
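The important part of this change is the ordering: the `UserGroup__ConnectorCredentialPair` rows are deleted and the session is flushed before the `ConnectorCredentialPair` itself is removed, so the foreign keys on the association table never block the delete. A generic SQLAlchemy sketch of that pattern (the `Parent`/`Association` models are illustrative, not the repo's):

```
# Illustrative only: delete association rows, flush, then delete the parent row,
# mirroring the ordering used in the force-delete script above.
from sqlalchemy import Column, ForeignKey, Integer, create_engine, delete
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Parent(Base):
    __tablename__ = "parent"
    id = Column(Integer, primary_key=True)


class Association(Base):
    __tablename__ = "association"
    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("parent.id"), nullable=False)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Parent(id=1), Association(id=1, parent_id=1)])
    session.commit()

    # 1) remove the rows that reference the parent
    session.execute(delete(Association).where(Association.parent_id == 1))
    # 2) flush so those deletes are applied before the parent delete
    session.flush()
    # 3) now delete the parent row itself
    session.execute(delete(Parent).where(Parent.id == 1))
    session.commit()
```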

View File

@@ -4,9 +4,7 @@ from shared_configs.enums import EmbedTextType
class EmbedRequest(BaseModel):
# This already includes any prefixes, the text is just passed directly to the model
texts: list[str]
# Can be None for cloud embedding model requests; error handling logic exists for other cases
model_name: str | None
max_context_length: int
@@ -14,10 +12,12 @@ class EmbedRequest(BaseModel):
api_key: str | None
provider_type: str | None
text_type: EmbedTextType
manual_query_prefix: str | None
manual_passage_prefix: str | None
class EmbedResponse(BaseModel):
embeddings: list[list[float]]
embeddings: list[list[float] | None]
class RerankRequest(BaseModel):
@@ -26,7 +26,7 @@ class RerankRequest(BaseModel):
class RerankResponse(BaseModel):
scores: list[list[float]]
scores: list[list[float] | None]
class IntentRequest(BaseModel):

View File

@@ -9,7 +9,7 @@ This Python script automates the process of running search quality tests for a b
- Manages environment variables
- Switches to specified Git branch
- Uploads test documents
- Runs search quality tests using Relari
- Runs search quality tests
- Cleans up Docker containers (optional)
## Usage
@@ -29,9 +29,17 @@ export PYTHONPATH=$PYTHONPATH:$PWD/backend
```
cd backend/tests/regression/answer_quality
```
7. Run the script:
7. To launch the evaluation environment, run the launch_eval_env.py script (this step can be skipped if you are running the environment outside of Docker; just leave "environment_name" blank):
```
python run_eval_pipeline.py
python launch_eval_env.py
```
8. Run the file_uploader.py script to upload the zip files located at the path "zipped_documents_file":
```
python file_uploader.py
```
9. Run the run_qa.py script to ask questions from the jsonl located at the path "questions_file". This will hit the "query/answer-with-quote" API endpoint.
```
python run_qa.py
```
Note: All data will be saved even after the containers are shut down. There are instructions below for re-launching the docker containers using this data.
@@ -61,6 +69,11 @@ Edit `search_test_config.yaml` to set:
- Set this to true to automatically delete all docker containers, networks and volumes after the test
- launch_web_ui
- Set this to true if you want to use the UI during/after the testing process
- only_state
- Whether to only run Vespa and Postgres
- only_retrieve_docs
- Set true to only retrieve documents, not LLM response
- This is to save on API costs
- use_cloud_gpu
- Set to true or false depending on if you want to use the remote gpu
- Only need to set this if use_cloud_gpu is true
@@ -70,12 +83,10 @@ Edit `search_test_config.yaml` to set:
- model_server_port
- This is the port of the remote model server
- Only need to set this if use_cloud_gpu is true
- existing_test_suffix
- environment_name
- Use this if you would like to relaunch a previous test instance
- Input the suffix of the test you'd like to re-launch
- (E.g. to use the data from folder "test-1234-5678" put "-1234-5678")
- No new files will automatically be uploaded
- Leave empty to run a new test
- Input the env_name of the test you'd like to re-launch
- Leave empty to launch referencing local default network locations
- limit
- Max number of questions you'd like to ask against the dataset
- Set to null for no limit
@@ -85,7 +96,7 @@ Edit `search_test_config.yaml` to set:
## Relaunching From Existing Data
To launch an existing set of containers that has already completed indexing, set the existing_test_suffix variable. This will launch the docker containers mounted on the volumes of the indicated suffix and will not automatically index any documents or run any QA.
To launch an existing set of containers that has already completed indexing, set the environment_name variable. This will launch the docker containers mounted on the volumes of the indicated env_name and will not automatically index any documents or run any QA.
Once these containers are launched you can run file_uploader.py or run_qa.py (assuming you have run the steps in the Usage section above).
- file_uploader.py will upload and index additional zipped files located at the zipped_documents_file path.
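For reference, a minimal excerpt of `search_test_config.yaml` using the renamed options described above (just the template defaults, not a complete config):

```
# Name for existing testing env (empty string uses default ports)
environment_name: ""
# Only retrieve documents, not LLM response (saves on API costs)
only_retrieve_docs: false
# Whether to only run Vespa and Postgres
only_state: false
# Limit on number of tests to run (null means no limit)
limit: null
```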

View File

@@ -2,45 +2,31 @@ import requests
from retry import retry
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.connectors.models import InputType
from danswer.db.enums import IndexingStatus
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.models import IndexFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.server.documents.models import ConnectorBase
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
from tests.regression.answer_quality.cli_utils import get_api_server_host_port
GENERAL_HEADERS = {"Content-Type": "application/json"}
def _api_url_builder(run_suffix: str, api_path: str) -> str:
return f"http://localhost:{get_api_server_host_port(run_suffix)}" + api_path
def _create_new_chat_session(run_suffix: str) -> int:
create_chat_request = ChatSessionCreationRequest(
persona_id=0,
description=None,
)
body = create_chat_request.dict()
create_chat_url = _api_url_builder(run_suffix, "/chat/create-chat-session/")
response_json = requests.post(
create_chat_url, headers=GENERAL_HEADERS, json=body
).json()
chat_session_id = response_json.get("chat_session_id")
if isinstance(chat_session_id, int):
return chat_session_id
def _api_url_builder(env_name: str, api_path: str) -> str:
if env_name:
return f"http://localhost:{get_api_server_host_port(env_name)}" + api_path
else:
raise RuntimeError(response_json)
return "http://localhost:8080" + api_path
@retry(tries=10, delay=10)
def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
@retry(tries=5, delay=5)
def get_answer_from_query(
query: str, only_retrieve_docs: bool, env_name: str
) -> tuple[list[str], str]:
filters = IndexFilters(
source_type=None,
document_set=None,
@@ -48,42 +34,47 @@ def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
tags=None,
access_control_list=None,
)
retrieval_options = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
filters=filters,
enable_auto_detect_filters=False,
messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]
new_message_request = DirectQARequest(
messages=messages,
prompt_id=0,
persona_id=0,
retrieval_options=RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
filters=filters,
enable_auto_detect_filters=False,
),
chain_of_thought=False,
return_contexts=True,
skip_gen_ai_answer_generation=only_retrieve_docs,
)
chat_session_id = _create_new_chat_session(run_suffix)
url = _api_url_builder(run_suffix, "/chat/send-message-simple-api/")
new_message_request = BasicCreateChatMessageRequest(
chat_session_id=chat_session_id,
message=query,
retrieval_options=retrieval_options,
query_override=query,
)
url = _api_url_builder(env_name, "/query/answer-with-quote/")
headers = {
"Content-Type": "application/json",
}
body = new_message_request.dict()
body["user"] = None
try:
response_json = requests.post(url, headers=GENERAL_HEADERS, json=body).json()
simple_search_docs = response_json.get("simple_search_docs", [])
answer = response_json.get("answer", "")
response_json = requests.post(url, headers=headers, json=body).json()
context_data_list = response_json.get("contexts", {}).get("contexts", [])
answer = response_json.get("answer", "") or ""
except Exception as e:
print("Failed to answer the questions:")
print(f"\t {str(e)}")
print("trying again")
print("Try restarting vespa container and trying agian")
raise e
return simple_search_docs, answer
return context_data_list, answer
@retry(tries=10, delay=10)
def check_if_query_ready(run_suffix: str) -> bool:
url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/")
def check_indexing_status(env_name: str) -> tuple[int, bool]:
url = _api_url_builder(env_name, "/manage/admin/connector/indexing-status/")
try:
indexing_status_dict = requests.get(url, headers=GENERAL_HEADERS).json()
except Exception as e:
@@ -98,20 +89,21 @@ def check_if_query_ready(run_suffix: str) -> bool:
status = index_attempt["last_status"]
if status == IndexingStatus.IN_PROGRESS or status == IndexingStatus.NOT_STARTED:
ongoing_index_attempts = True
elif status == IndexingStatus.SUCCESS:
doc_count += 16
doc_count += index_attempt["docs_indexed"]
doc_count -= 16
if not doc_count:
print("No docs indexed, waiting for indexing to start")
elif ongoing_index_attempts:
print(
f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
)
return doc_count > 0 and not ongoing_index_attempts
# all the +16 and -16 are to account for the fact that the indexing status
# is only updated every 16 documents and tells us how many are
# chunked, not indexed. This should probably be fixed in the future!
if doc_count:
doc_count += 16
return doc_count, ongoing_index_attempts
def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/")
def run_cc_once(env_name: str, connector_id: int, credential_id: int) -> None:
url = _api_url_builder(env_name, "/manage/admin/connector/run-once/")
body = {
"connector_id": connector_id,
"credential_ids": [credential_id],
@@ -126,9 +118,9 @@ def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
print("Failed text:", response.text)
def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> None:
def create_cc_pair(env_name: str, connector_id: int, credential_id: int) -> None:
url = _api_url_builder(
run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}"
env_name, f"/manage/connector/{connector_id}/credential/{credential_id}"
)
body = {"name": "zip_folder_contents", "is_public": True}
@@ -141,8 +133,8 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
print("Failed text:", response.text)
def _get_existing_connector_names(run_suffix: str) -> list[str]:
url = _api_url_builder(run_suffix, "/manage/connector")
def _get_existing_connector_names(env_name: str) -> list[str]:
url = _api_url_builder(env_name, "/manage/connector")
body = {
"credential_json": {},
@@ -156,10 +148,10 @@ def _get_existing_connector_names(run_suffix: str) -> list[str]:
raise RuntimeError(response.__dict__)
def create_connector(run_suffix: str, file_paths: list[str]) -> int:
url = _api_url_builder(run_suffix, "/manage/admin/connector")
def create_connector(env_name: str, file_paths: list[str]) -> int:
url = _api_url_builder(env_name, "/manage/admin/connector")
connector_name = base_connector_name = "search_eval_connector"
existing_connector_names = _get_existing_connector_names(run_suffix)
existing_connector_names = _get_existing_connector_names(env_name)
count = 1
while connector_name in existing_connector_names:
@@ -186,8 +178,8 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
raise RuntimeError(response.__dict__)
def create_credential(run_suffix: str) -> int:
url = _api_url_builder(run_suffix, "/manage/credential")
def create_credential(env_name: str) -> int:
url = _api_url_builder(env_name, "/manage/credential")
body = {
"credential_json": {},
"admin_public": True,
@@ -201,12 +193,12 @@ def create_credential(run_suffix: str) -> int:
@retry(tries=10, delay=2, backoff=2)
def upload_file(run_suffix: str, zip_file_path: str) -> list[str]:
def upload_file(env_name: str, zip_file_path: str) -> list[str]:
files = [
("files", open(zip_file_path, "rb")),
]
api_path = _api_url_builder(run_suffix, "/manage/admin/connector/file/upload")
api_path = _api_url_builder(env_name, "/manage/admin/connector/file/upload")
try:
response = requests.post(api_path, files=files)
response.raise_for_status() # Raises an HTTPError for bad responses

View File

@@ -67,20 +67,20 @@ def switch_to_commit(commit_sha: str) -> None:
print("Repository updated successfully.")
def get_docker_container_env_vars(suffix: str) -> dict:
def get_docker_container_env_vars(env_name: str) -> dict:
"""
Retrieves environment variables from "background" and "api_server" Docker containers.
"""
print(f"Getting environment variables for containers with suffix: {suffix}")
print(f"Getting environment variables for containers with env_name: {env_name}")
combined_env_vars = {}
for container_type in ["background", "api_server"]:
container_name = _run_command(
f"docker ps -a --format '{{{{.Names}}}}' | awk '/{container_type}/ && /{suffix}/'"
f"docker ps -a --format '{{{{.Names}}}}' | awk '/{container_type}/ && /{env_name}/'"
)[0].strip()
if not container_name:
raise RuntimeError(
f"No {container_type} container found with suffix: {suffix}"
f"No {container_type} container found with env_name: {env_name}"
)
env_vars_json = _run_command(
@@ -95,9 +95,9 @@ def get_docker_container_env_vars(suffix: str) -> dict:
return combined_env_vars
def manage_data_directories(suffix: str, base_path: str, use_cloud_gpu: bool) -> None:
def manage_data_directories(env_name: str, base_path: str, use_cloud_gpu: bool) -> None:
# Use the user's home directory as the base path
target_path = os.path.join(os.path.expanduser(base_path), suffix)
target_path = os.path.join(os.path.expanduser(base_path), env_name)
directories = {
"DANSWER_POSTGRES_DATA_DIR": os.path.join(target_path, "postgres/"),
"DANSWER_VESPA_DATA_DIR": os.path.join(target_path, "vespa/"),
@@ -144,12 +144,12 @@ def _is_port_in_use(port: int) -> bool:
def start_docker_compose(
run_suffix: str, launch_web_ui: bool, use_cloud_gpu: bool, only_state: bool = False
env_name: str, launch_web_ui: bool, use_cloud_gpu: bool, only_state: bool = False
) -> None:
print("Starting Docker Compose...")
os.chdir(os.path.dirname(__file__))
os.chdir("../../../../deployment/docker_compose/")
command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack-{run_suffix} up -d"
command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack-{env_name} up -d"
command += " --build"
command += " --force-recreate"
@@ -175,17 +175,17 @@ def start_docker_compose(
print("Containers have been launched")
def cleanup_docker(run_suffix: str) -> None:
def cleanup_docker(env_name: str) -> None:
print(
f"Deleting Docker containers, volumes, and networks for project suffix: {run_suffix}"
f"Deleting Docker containers, volumes, and networks for project env_name: {env_name}"
)
stdout, _ = _run_command("docker ps -a --format '{{json .}}'")
containers = [json.loads(line) for line in stdout.splitlines()]
if not run_suffix:
run_suffix = datetime.now().strftime("-%Y")
project_name = f"danswer-stack{run_suffix}"
if not env_name:
env_name = datetime.now().strftime("-%Y")
project_name = f"danswer-stack{env_name}"
containers_to_delete = [
c for c in containers if c["Names"].startswith(project_name)
]
@@ -221,23 +221,23 @@ def cleanup_docker(run_suffix: str) -> None:
networks = stdout.splitlines()
networks_to_delete = [n for n in networks if run_suffix in n]
networks_to_delete = [n for n in networks if env_name in n]
if not networks_to_delete:
print(f"No networks found containing suffix: {run_suffix}")
print(f"No networks found containing env_name: {env_name}")
else:
network_names = " ".join(networks_to_delete)
_run_command(f"docker network rm {network_names}")
print(
f"Successfully deleted {len(networks_to_delete)} networks containing suffix: {run_suffix}"
f"Successfully deleted {len(networks_to_delete)} networks containing env_name: {env_name}"
)
@retry(tries=5, delay=5, backoff=2)
def get_api_server_host_port(suffix: str) -> str:
def get_api_server_host_port(env_name: str) -> str:
"""
This pulls all containers with the provided suffix
This pulls all containers with the provided env_name
It then grabs the JSON for the specific container with a name containing "api_server"
It then grabs the port info from that JSON and strips out the relevant data
"""
@@ -248,16 +248,16 @@ def get_api_server_host_port(suffix: str) -> str:
server_jsons = []
for container in containers:
if container_name in container["Names"] and suffix in container["Names"]:
if container_name in container["Names"] and env_name in container["Names"]:
server_jsons.append(container)
if not server_jsons:
raise RuntimeError(
f"No container found containing: {container_name} and {suffix}"
f"No container found containing: {container_name} and {env_name}"
)
elif len(server_jsons) > 1:
raise RuntimeError(
f"Too many containers matching {container_name} found, please indicate a suffix"
f"Too many containers matching {container_name} found, please indicate a env_name"
)
server_json = server_jsons[0]
@@ -278,67 +278,37 @@ def get_api_server_host_port(suffix: str) -> str:
raise RuntimeError(f"Too many ports matching {client_port} found")
if not matching_ports:
raise RuntimeError(
f"No port found containing: {client_port} for container: {container_name} and suffix: {suffix}"
f"No port found containing: {client_port} for container: {container_name} and env_name: {env_name}"
)
return matching_ports[0]
# Added function to check Vespa container health status
def is_vespa_container_healthy(suffix: str) -> bool:
print(f"Checking health status of Vespa container for suffix: {suffix}")
# Find the Vespa container
stdout, _ = _run_command(
f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
)
container_name = stdout.strip()
if not container_name:
print(f"No Vespa container found with suffix: {suffix}")
return False
# Get the health status
stdout, _ = _run_command(
f"docker inspect --format='{{{{.State.Health.Status}}}}' {container_name}"
)
health_status = stdout.strip()
is_healthy = health_status.lower() == "healthy"
print(f"Vespa container '{container_name}' health status: {health_status}")
return is_healthy
# Added function to restart Vespa container
def restart_vespa_container(suffix: str) -> None:
print(f"Restarting Vespa container for suffix: {suffix}")
def restart_vespa_container(env_name: str) -> None:
print(f"Restarting Vespa container for env_name: {env_name}")
# Find the Vespa container
stdout, _ = _run_command(
f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
f"docker ps -a --format '{{{{.Names}}}}' | awk '/index-1/ && /{env_name}/'"
)
container_name = stdout.strip()
if not container_name:
raise RuntimeError(f"No Vespa container found with suffix: {suffix}")
raise RuntimeError(f"No Vespa container found with env_name: {env_name}")
# Restart the container
_run_command(f"docker restart {container_name}")
print(f"Vespa container '{container_name}' has begun restarting")
time_to_wait = 5
while not is_vespa_container_healthy(suffix):
print(f"Waiting {time_to_wait} seconds for vespa container to restart")
time.sleep(5)
time.sleep(30)
print(f"Vespa container '{container_name}' has been restarted")
if __name__ == "__main__":
"""
Running this just cleans up the docker environment for the container indicated by existing_test_suffix
If no existing_test_suffix is indicated, will just clean up all danswer docker containers/volumes/networks
Running this just cleans up the docker environment for the container indicated by environment_name
If no environment_name is indicated, will just clean up all danswer docker containers/volumes/networks
Note: vespa/postgres mounts are not deleted
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -348,4 +318,4 @@ if __name__ == "__main__":
if not isinstance(config, dict):
raise TypeError("config must be a dictionary")
cleanup_docker(config["existing_test_suffix"])
cleanup_docker(config["environment_name"])

View File

@@ -1,8 +1,13 @@
import os
import tempfile
import time
import zipfile
from pathlib import Path
from types import SimpleNamespace
import yaml
from tests.regression.answer_quality.api_utils import check_indexing_status
from tests.regression.answer_quality.api_utils import create_cc_pair
from tests.regression.answer_quality.api_utils import create_connector
from tests.regression.answer_quality.api_utils import create_credential
@@ -10,15 +15,65 @@ from tests.regression.answer_quality.api_utils import run_cc_once
from tests.regression.answer_quality.api_utils import upload_file
def upload_test_files(zip_file_path: str, run_suffix: str) -> None:
def unzip_and_get_file_paths(zip_file_path: str) -> list[str]:
persistent_dir = tempfile.mkdtemp()
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
zip_ref.extractall(persistent_dir)
return [str(path) for path in Path(persistent_dir).rglob("*") if path.is_file()]
def create_temp_zip_from_files(file_paths: list[str]) -> str:
persistent_dir = tempfile.mkdtemp()
zip_file_path = os.path.join(persistent_dir, "temp.zip")
with zipfile.ZipFile(zip_file_path, "w") as zip_file:
for file_path in file_paths:
zip_file.write(file_path, Path(file_path).name)
return zip_file_path
def upload_test_files(zip_file_path: str, env_name: str) -> None:
print("zip:", zip_file_path)
file_paths = upload_file(run_suffix, zip_file_path)
file_paths = upload_file(env_name, zip_file_path)
conn_id = create_connector(run_suffix, file_paths)
cred_id = create_credential(run_suffix)
conn_id = create_connector(env_name, file_paths)
cred_id = create_credential(env_name)
create_cc_pair(run_suffix, conn_id, cred_id)
run_cc_once(run_suffix, conn_id, cred_id)
create_cc_pair(env_name, conn_id, cred_id)
run_cc_once(env_name, conn_id, cred_id)
def manage_file_upload(zip_file_path: str, env_name: str) -> None:
unzipped_file_paths = unzip_and_get_file_paths(zip_file_path)
total_file_count = len(unzipped_file_paths)
while True:
doc_count, ongoing_index_attempts = check_indexing_status(env_name)
if ongoing_index_attempts:
print(
f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
)
elif not doc_count:
print("No docs indexed, waiting for indexing to start")
upload_test_files(zip_file_path, env_name)
elif doc_count < total_file_count:
print(f"No ongooing indexing attempts but only {doc_count} docs indexed")
remaining_files = unzipped_file_paths[doc_count:]
print(f"Grabbed last {len(remaining_files)} docs to try agian")
temp_zip_file_path = create_temp_zip_from_files(remaining_files)
upload_test_files(temp_zip_file_path, env_name)
os.unlink(temp_zip_file_path)
else:
print(f"Successfully uploaded {doc_count} docs!")
break
time.sleep(10)
for file in unzipped_file_paths:
os.unlink(file)
if __name__ == "__main__":
@@ -27,5 +82,5 @@ if __name__ == "__main__":
with open(config_path, "r") as file:
config = SimpleNamespace(**yaml.safe_load(file))
file_location = config.zipped_documents_file
run_suffix = config.existing_test_suffix
upload_test_files(file_location, run_suffix)
env_name = config.environment_name
manage_file_upload(file_location, env_name)

View File

@@ -1,16 +1,12 @@
import os
from datetime import datetime
from types import SimpleNamespace
import yaml
from tests.regression.answer_quality.cli_utils import cleanup_docker
from tests.regression.answer_quality.cli_utils import manage_data_directories
from tests.regression.answer_quality.cli_utils import set_env_variables
from tests.regression.answer_quality.cli_utils import start_docker_compose
from tests.regression.answer_quality.cli_utils import switch_to_commit
from tests.regression.answer_quality.file_uploader import upload_test_files
from tests.regression.answer_quality.run_qa import run_qa_test_and_save_results
def load_config(config_filename: str) -> SimpleNamespace:
@@ -22,12 +18,16 @@ def load_config(config_filename: str) -> SimpleNamespace:
def main() -> None:
config = load_config("search_test_config.yaml")
if config.existing_test_suffix:
run_suffix = config.existing_test_suffix
print("launching danswer with existing data suffix:", run_suffix)
if config.environment_name:
env_name = config.environment_name
print("launching danswer with environment name:", env_name)
else:
run_suffix = datetime.now().strftime("-%Y%m%d-%H%M%S")
print("run_suffix:", run_suffix)
print("No env name defined. Not launching docker.")
print(
"Please define a name in the config yaml to start a new env "
"or use an existing env"
)
return
set_env_variables(
config.model_server_ip,
@@ -35,22 +35,14 @@ def main() -> None:
config.use_cloud_gpu,
config.llm,
)
manage_data_directories(run_suffix, config.output_folder, config.use_cloud_gpu)
manage_data_directories(env_name, config.output_folder, config.use_cloud_gpu)
if config.commit_sha:
switch_to_commit(config.commit_sha)
start_docker_compose(
run_suffix, config.launch_web_ui, config.use_cloud_gpu, config.only_state
env_name, config.launch_web_ui, config.use_cloud_gpu, config.only_state
)
if not config.existing_test_suffix and not config.only_state:
upload_test_files(config.zipped_documents_file, run_suffix)
run_qa_test_and_save_results(run_suffix)
if config.clean_up_docker_containers:
cleanup_docker(run_suffix)
if __name__ == "__main__":
main()

View File

@@ -6,7 +6,6 @@ import time
import yaml
from tests.regression.answer_quality.api_utils import check_if_query_ready
from tests.regression.answer_quality.api_utils import get_answer_from_query
from tests.regression.answer_quality.cli_utils import get_current_commit_sha
from tests.regression.answer_quality.cli_utils import get_docker_container_env_vars
@@ -44,12 +43,12 @@ def _read_questions_jsonl(questions_file_path: str) -> list[dict]:
def _get_test_output_folder(config: dict) -> str:
base_output_folder = os.path.expanduser(config["output_folder"])
if config["run_suffix"]:
if config["env_name"]:
base_output_folder = os.path.join(
base_output_folder, config["run_suffix"], "evaluations_output"
base_output_folder, config["env_name"], "evaluations_output"
)
else:
base_output_folder = os.path.join(base_output_folder, "no_defined_suffix")
base_output_folder = os.path.join(base_output_folder, "no_defined_env_name")
counter = 1
output_folder_path = os.path.join(base_output_folder, "run_1")
@@ -73,12 +72,12 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
metadata = {
"commit_sha": get_current_commit_sha(),
"run_suffix": config["run_suffix"],
"env_name": config["env_name"],
"test_config": config,
"number_of_questions_in_dataset": len(questions),
}
env_vars = get_docker_container_env_vars(config["run_suffix"])
env_vars = get_docker_container_env_vars(config["env_name"])
if env_vars["ENV_SEED_CONFIGURATION"]:
del env_vars["ENV_SEED_CONFIGURATION"]
if env_vars["GPG_KEY"]:
@@ -118,7 +117,8 @@ def _process_question(question_data: dict, config: dict, question_number: int) -
print(f"query: {query}")
context_data_list, answer = get_answer_from_query(
query=query,
run_suffix=config["run_suffix"],
only_retrieve_docs=config["only_retrieve_docs"],
env_name=config["env_name"],
)
if not context_data_list:
@@ -142,27 +142,23 @@ def _process_and_write_query_results(config: dict) -> None:
test_output_folder, questions = _initialize_files(config)
print("saving test results to folder:", test_output_folder)
while not check_if_query_ready(config["run_suffix"]):
time.sleep(5)
if config["limit"] is not None:
questions = questions[: config["limit"]]
with multiprocessing.Pool(processes=multiprocessing.cpu_count() * 2) as pool:
# Use multiprocessing to process questions
with multiprocessing.Pool() as pool:
results = pool.starmap(
_process_question, [(q, config, i + 1) for i, q in enumerate(questions)]
_process_question,
[(question, config, i + 1) for i, question in enumerate(questions)],
)
_populate_results_file(test_output_folder, results)
invalid_answer_count = 0
for result in results:
if not result.get("answer"):
if len(result["context_data_list"]) == 0:
invalid_answer_count += 1
if not result.get("context_data_list"):
raise RuntimeError("Search failed, this is a critical failure!")
_update_metadata_file(test_output_folder, invalid_answer_count)
if invalid_answer_count:
@@ -177,7 +173,7 @@ def _process_and_write_query_results(config: dict) -> None:
print("saved test results to folder:", test_output_folder)
def run_qa_test_and_save_results(run_suffix: str = "") -> None:
def run_qa_test_and_save_results(env_name: str = "") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, "search_test_config.yaml")
with open(config_path, "r") as file:
@@ -186,16 +182,16 @@ def run_qa_test_and_save_results(run_suffix: str = "") -> None:
if not isinstance(config, dict):
raise TypeError("config must be a dictionary")
if not run_suffix:
run_suffix = config["existing_test_suffix"]
if not env_name:
env_name = config["environment_name"]
config["run_suffix"] = run_suffix
config["env_name"] = env_name
_process_and_write_query_results(config)
if __name__ == "__main__":
"""
To run a different set of questions, update the questions_file in search_test_config.yaml
If there is more than one instance of Danswer running, specify the suffix in search_test_config.yaml
If there is more than one instance of Danswer running, specify the env_name in search_test_config.yaml
"""
run_qa_test_and_save_results()

View File

@@ -13,14 +13,11 @@ questions_file: "~/sample_questions.yaml"
# Git commit SHA to use (null means use current code as is)
commit_sha: null
# Whether to remove Docker containers after the test
clean_up_docker_containers: true
# Whether to launch a web UI for the test
launch_web_ui: false
# Whether to only run Vespa and Postgres
only_state: false
# Only retrieve documents, not LLM response
only_retrieve_docs: false
# Whether to use a cloud GPU for processing
use_cloud_gpu: false
@@ -31,9 +28,8 @@ model_server_ip: "PUT_PUBLIC_CLOUD_IP_HERE"
# Port of the model server (placeholder)
model_server_port: "PUT_PUBLIC_CLOUD_PORT_HERE"
# Suffix for existing test results (E.g. -1234-5678)
# empty string means no suffix
existing_test_suffix: ""
# Name for existing testing env (empty string uses default ports)
environment_name: ""
# Limit on number of tests to run (null means no limit)
limit: null

View File

@@ -36,8 +36,7 @@ export function ModelSelectionConfirmationModal({
at least 16GB of RAM to Danswer during this process.
</Text>
{/* TODO Change this back- ensure functional */}
{!isCustom && (
{isCustom && (
<Callout title="IMPORTANT" color="yellow" className="mt-4">
We&apos;ve detected that this is a custom-specified embedding
model. Since we have to download the model files before verifying

View File

@@ -149,6 +149,7 @@ function Main() {
}
);
if (response.ok) {
setShowTentativeOpenProvider(null);
setShowTentativeModel(null);
mutate("/api/secondary-index/get-secondary-embedding-model");
if (!connectors || !connectors.length) {
@@ -274,6 +275,7 @@ function Main() {
onClose={() => setAlreadySelectedModel(null)}
/>
)}
{showTentativeOpenProvider && (
<ModelSelectionConfirmationModal
selectedModel={showTentativeOpenProvider}

View File

@@ -1,5 +1,6 @@
import { LoadingAnimation } from "@/components/Loading";
import { Button, Divider, Text } from "@tremor/react";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import {
ArrayHelpers,
ErrorMessage,
@@ -15,11 +16,16 @@ import {
SubLabel,
TextArrayField,
TextFormField,
BooleanFormField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { useSWRConfig } from "swr";
import { useUserGroups } from "@/lib/hooks";
import { FullLLMProvider } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
@@ -44,9 +50,16 @@ export function CustomLLMProviderUpdateForm({
}) {
const { mutate } = useSWRConfig();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>("");
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name ?? "",
@@ -61,6 +74,8 @@ export function CustomLLMProviderUpdateForm({
custom_config_list: existingLlmProvider?.custom_config
? Object.entries(existingLlmProvider.custom_config)
: [],
is_public: existingLlmProvider?.is_public ?? true,
groups: existingLlmProvider?.groups ?? [],
};
// Setup validation schema if required
@@ -74,6 +89,9 @@ export function CustomLLMProviderUpdateForm({
default_model_name: Yup.string().required("Model name is required"),
fast_default_model_name: Yup.string().nullable(),
custom_config_list: Yup.array(),
// EE Only
is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
});
return (
@@ -97,6 +115,9 @@ export function CustomLLMProviderUpdateForm({
return;
}
// don't set groups if marked as public
const groups = values.is_public ? [] : values.groups;
// test the configuration
if (!isEqual(values, initialValues)) {
setIsTesting(true);
@@ -188,93 +209,97 @@ export function CustomLLMProviderUpdateForm({
setSubmitting(false);
}}
>
{({ values }) => (
<Form>
<TextFormField
name="name"
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
/>
{({ values, setFieldValue }) => {
return (
<Form>
<TextFormField
name="name"
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
/>
<Divider />
<Divider />
<TextFormField
name="provider"
label="Provider Name"
subtext={
<TextFormField
name="provider"
label="Provider Name"
subtext={
<>
Should be one of the providers listed at{" "}
<a
target="_blank"
href="https://docs.litellm.ai/docs/providers"
className="text-link"
>
https://docs.litellm.ai/docs/providers
</a>
.
</>
}
placeholder="Name of the custom provider"
/>
<Divider />
<SubLabel>
Fill in the following as is needed. Refer to the LiteLLM
documentation for the model provider name specified above in order
to determine which fields are required.
</SubLabel>
<TextFormField
name="api_key"
label="[Optional] API Key"
placeholder="API Key"
type="password"
/>
<TextFormField
name="api_base"
label="[Optional] API Base"
placeholder="API Base"
/>
<TextFormField
name="api_version"
label="[Optional] API Version"
placeholder="API Version"
/>
<Label>[Optional] Custom Configs</Label>
<SubLabel>
<>
Should be one of the providers listed at{" "}
<a
target="_blank"
href="https://docs.litellm.ai/docs/providers"
className="text-link"
>
https://docs.litellm.ai/docs/providers
</a>
.
<div>
Additional configurations needed by the model provider. Are
passed to litellm via environment variables.
</div>
<div className="mt-2">
For example, when configuring the Cloudflare provider, you
would need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
Cloudflare account ID as the value.
</div>
</>
}
placeholder="Name of the custom provider"
/>
</SubLabel>
<Divider />
<SubLabel>
Fill in the following as is needed. Refer to the LiteLLM
documentation for the model provider name specified above in order
to determine which fields are required.
</SubLabel>
<TextFormField
name="api_key"
label="[Optional] API Key"
placeholder="API Key"
type="password"
/>
<TextFormField
name="api_base"
label="[Optional] API Base"
placeholder="API Base"
/>
<TextFormField
name="api_version"
label="[Optional] API Version"
placeholder="API Version"
/>
<Label>[Optional] Custom Configs</Label>
<SubLabel>
<>
<div>
Additional configurations needed by the model provider. Are
passed to litellm via environment variables.
</div>
<div className="mt-2">
For example, when configuring the Cloudflare provider, you would
need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
Cloudflare account ID as the value.
</div>
</>
</SubLabel>
<FieldArray
name="custom_config_list"
render={(arrayHelpers: ArrayHelpers<any[]>) => (
<div>
{values.custom_config_list.map((_, index) => {
return (
<div key={index} className={index === 0 ? "mt-2" : "mt-6"}>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label>Key</Label>
<Field
name={`custom_config_list[${index}][0]`}
className={`
<FieldArray
name="custom_config_list"
render={(arrayHelpers: ArrayHelpers<any[]>) => (
<div>
{values.custom_config_list.map((_, index) => {
return (
<div
key={index}
className={index === 0 ? "mt-2" : "mt-6"}
>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label>Key</Label>
<Field
name={`custom_config_list[${index}][0]`}
className={`
border
border-border
bg-background
@@ -284,20 +309,20 @@ export function CustomLLMProviderUpdateForm({
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][0]`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][0]`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
<div className="mt-3">
<Label>Value</Label>
<Field
name={`custom_config_list[${index}][1]`}
className={`
<div className="mt-3">
<Label>Value</Label>
<Field
name={`custom_config_list[${index}][1]`}
className={`
border
border-border
bg-background
@@ -307,121 +332,190 @@ export function CustomLLMProviderUpdateForm({
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][1]`}
component="div"
className="text-error text-sm mt-1"
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][1]`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() => arrayHelpers.remove(index)}
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() => arrayHelpers.remove(index)}
/>
</div>
</div>
</div>
);
})}
);
})}
<Button
onClick={() => {
arrayHelpers.push(["", ""]);
}}
className="mt-3"
color="green"
size="xs"
type="button"
icon={FiPlus}
>
Add New
</Button>
</div>
)}
/>
<Button
onClick={() => {
arrayHelpers.push(["", ""]);
}}
className="mt-3"
color="green"
size="xs"
type="button"
icon={FiPlus}
>
Add New
</Button>
</div>
)}
/>
<Divider />
<Divider />
<TextArrayField
name="model_names"
label="Model Names"
values={values}
subtext={`List the individual models that you want to make
<TextArrayField
name="model_names"
label="Model Names"
values={values}
subtext={`List the individual models that you want to make
available as a part of this provider. At least one must be specified.
As an example, for OpenAI one model might be "gpt-4".`}
/>
/>
<Divider />
<Divider />
<TextFormField
name="default_model_name"
subtext={`
<TextFormField
name="default_model_name"
subtext={`
The model to use by default for this provider unless
otherwise specified. Must be one of the models listed
above.`}
label="Default Model"
placeholder="E.g. gpt-4"
/>
label="Default Model"
placeholder="E.g. gpt-4"
/>
<TextFormField
name="fast_default_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
<TextFormField
name="fast_default_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
for this provider. If not set, will use
the Default Model configured above.`}
label="[Optional] Fast Model"
placeholder="E.g. gpt-4"
/>
label="[Optional] Fast Model"
placeholder="E.g. gpt-4"
/>
<Divider />
<Divider />
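(A minimal sketch of how the three model fields described above might be filled in for an OpenAI-style provider; only "gpt-4" appears in the surrounding subtext, the other model name is an illustrative placeholder.)
// Hypothetical values for the Model Names, Default Model and Fast Model fields.
const modelConfig = {
  model_names: ["gpt-4", "gpt-3.5-turbo"], // at least one must be specified
  default_model_name: "gpt-4", // must be one of model_names
  fast_default_model_name: "gpt-3.5-turbo", // optional, falls back to the default model
};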
<div>
{/* NOTE: this is above the test button to make sure it's visible */}
{testError && <Text className="text-error mt-2">{testError}</Text>}
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
/>
<div className="flex w-full mt-4">
<Button type="submit" size="xs">
{isTesting ? (
<LoadingAnimation text="Testing" />
) : existingLlmProvider ? (
"Update"
) : (
"Enable"
{showAdvancedOptions && (
<>
{isPaidEnterpriseFeaturesEnabled && userGroups && (
<>
<BooleanFormField
small
noPadding
alignTop
name="is_public"
label="Is Public?"
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
/>
{userGroups &&
userGroups.length > 0 &&
!values.is_public && (
<div>
<Text>
Select which User Groups should have access to this
LLM Provider.
</Text>
<div className="flex flex-wrap gap-2 mt-2">
{userGroups.map((userGroup) => {
const isSelected = values.groups.includes(
userGroup.id
);
return (
<Bubble
key={userGroup.id}
isSelected={isSelected}
onClick={() => {
if (isSelected) {
setFieldValue(
"groups",
values.groups.filter(
(id) => id !== userGroup.id
)
);
} else {
setFieldValue("groups", [
...values.groups,
userGroup.id,
]);
}
}}
>
<div className="flex">
<GroupsIcon />
<div className="ml-1">{userGroup.name}</div>
</div>
</Bubble>
);
})}
</div>
</div>
)}
</>
)}
</Button>
{existingLlmProvider && (
<Button
type="button"
color="red"
className="ml-3"
size="xs"
icon={FiTrash}
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
{
method: "DELETE",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
alert(`Failed to delete provider: ${errorMsg}`);
return;
}
</>
)}
mutate(LLM_PROVIDERS_ADMIN_URL);
onClose();
}}
>
Delete
</Button>
<div>
{/* NOTE: this is above the test button to make sure it's visible */}
{testError && (
<Text className="text-error mt-2">{testError}</Text>
)}
<div className="flex w-full mt-4">
<Button type="submit" size="xs">
{isTesting ? (
<LoadingAnimation text="Testing" />
) : existingLlmProvider ? (
"Update"
) : (
"Enable"
)}
</Button>
{existingLlmProvider && (
<Button
type="button"
color="red"
className="ml-3"
size="xs"
icon={FiTrash}
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
{
method: "DELETE",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
alert(`Failed to delete provider: ${errorMsg}`);
return;
}
mutate(LLM_PROVIDERS_ADMIN_URL);
onClose();
}}
>
Delete
</Button>
)}
</div>
</div>
</div>
</Form>
)}
</Form>
);
}}
</Formik>
);
}

View File

@@ -1,4 +1,5 @@
import { LoadingAnimation } from "@/components/Loading";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { Button, Divider, Text } from "@tremor/react";
import { Form, Formik } from "formik";
import { FiTrash } from "react-icons/fi";
@@ -6,11 +7,16 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import {
SelectorFormField,
TextFormField,
BooleanFormField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { useSWRConfig } from "swr";
import { useUserGroups } from "@/lib/hooks";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
@@ -29,9 +35,16 @@ export function LLMProviderUpdateForm({
}) {
const { mutate } = useSWRConfig();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>("");
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name ?? "",
@@ -54,6 +67,8 @@ export function LLMProviderUpdateForm({
},
{} as { [key: string]: string }
),
is_public: existingLlmProvider?.is_public ?? true,
groups: existingLlmProvider?.groups ?? [],
};
const [validatedConfig, setValidatedConfig] = useState(
@@ -91,6 +106,9 @@ export function LLMProviderUpdateForm({
: {}),
default_model_name: Yup.string().required("Model name is required"),
fast_default_model_name: Yup.string().nullable(),
// EE Only
is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
});
return (
@@ -193,7 +211,7 @@ export function LLMProviderUpdateForm({
setSubmitting(false);
}}
>
{({ values }) => (
{({ values, setFieldValue }) => (
<Form>
<TextFormField
name="name"
@@ -293,6 +311,69 @@ export function LLMProviderUpdateForm({
<Divider />
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
setShowAdvancedOptions={setShowAdvancedOptions}
/>
{showAdvancedOptions && (
<>
{isPaidEnterpriseFeaturesEnabled && userGroups && (
<>
<BooleanFormField
small
noPadding
alignTop
name="is_public"
label="Is Public?"
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
/>
{userGroups && userGroups.length > 0 && !values.is_public && (
<div>
<Text>
Select which User Groups should have access to this LLM
Provider.
</Text>
<div className="flex flex-wrap gap-2 mt-2">
{userGroups.map((userGroup) => {
const isSelected = values.groups.includes(
userGroup.id
);
return (
<Bubble
key={userGroup.id}
isSelected={isSelected}
onClick={() => {
if (isSelected) {
setFieldValue(
"groups",
values.groups.filter(
(id) => id !== userGroup.id
)
);
} else {
setFieldValue("groups", [
...values.groups,
userGroup.id,
]);
}
}}
>
<div className="flex">
<GroupsIcon />
<div className="ml-1">{userGroup.name}</div>
</div>
</Bubble>
);
})}
</div>
</div>
)}
</>
)}
</>
)}
<div>
{/* NOTE: this is above the test button to make sure it's visible */}
{testError && <Text className="text-error mt-2">{testError}</Text>}

View File

@@ -17,6 +17,8 @@ export interface WellKnownLLMProviderDescriptor {
llm_names: string[];
default_model: string | null;
default_fast_model: string | null;
is_public: boolean;
groups: number[];
}
export interface LLMProvider {
@@ -28,6 +30,8 @@ export interface LLMProvider {
custom_config: { [key: string]: string } | null;
default_model_name: string;
fast_default_model_name: string | null;
is_public: boolean;
groups: number[];
}
export interface FullLLMProvider extends LLMProvider {
@@ -44,4 +48,6 @@ export interface LLMProviderDescriptor {
default_model_name: string;
fast_default_model_name: string | null;
is_default_provider: boolean | null;
is_public: boolean;
groups: number[];
}
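(For concreteness, a hypothetical provider object shaped by the interfaces above and restricted to two user groups; the group ids and model values are made up, not taken from this diff.)
// Illustrative only: a non-public provider limited to user groups 1 and 4.
const restrictedProvider: Partial<LLMProvider> = {
  default_model_name: "gpt-4",
  fast_default_model_name: null,
  is_public: false,
  groups: [1, 4],
};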

View File

@@ -50,7 +50,7 @@ export default async function GalleryPage({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availablePersonas: assistants,
availableAssistants: assistants,
availableTags: tags,
llmProviders,
folders,

View File

@@ -52,7 +52,7 @@ export default async function GalleryPage({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availablePersonas: assistants,
availableAssistants: assistants,
availableTags: tags,
llmProviders,
folders,

View File

@@ -14,15 +14,18 @@ export function ChatBanner() {
return (
<div
className={`
z-[39]
h-[30px]
bg-background-100
shadow-sm
m-2
rounded
border-border
border
flex`}
mt-8
mb-2
mx-2
z-[39]
w-full
h-[30px]
bg-background-100
shadow-sm
rounded
border-border
border
flex`}
>
<div className="mx-auto text-emphasis text-sm flex flex-col">
<div className="my-auto">

View File

@@ -63,7 +63,6 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
import { ConfigurationModal } from "./modal/configuration/ConfigurationModal";
import { useChatContext } from "@/components/context/ChatContext";
import { v4 as uuidv4 } from "uuid";
import { orderAssistantsForUser } from "@/lib/assistants/orderAssistants";
@@ -82,38 +81,29 @@ const SYSTEM_MESSAGE_ID = -3;
export function ChatPage({
toggle,
documentSidebarInitialWidth,
defaultSelectedPersonaId,
defaultSelectedAssistantId,
toggledSidebar,
}: {
toggle: () => void;
documentSidebarInitialWidth?: number;
defaultSelectedPersonaId?: number;
defaultSelectedAssistantId?: number;
toggledSidebar: boolean;
}) {
const [configModalActiveTab, setConfigModalActiveTab] = useState<
string | null
>(null);
const router = useRouter();
const searchParams = useSearchParams();
let {
user,
chatSessions,
availableSources,
availableDocumentSets,
availablePersonas,
availableAssistants,
llmProviders,
folders,
openedFolders,
} = useChatContext();
const filteredAssistants = orderAssistantsForUser(availablePersonas, user);
const [selectedAssistant, setSelectedAssistant] = useState<Persona | null>(
null
);
const [alternativeGeneratingAssistant, setAlternativeGeneratingAssistant] =
useState<Persona | null>(null);
const router = useRouter();
const searchParams = useSearchParams();
// chat session
const existingChatIdRaw = searchParams.get("chatId");
const existingChatSessionId = existingChatIdRaw
? parseInt(existingChatIdRaw)
@@ -123,9 +113,51 @@ export function ChatPage({
);
const chatSessionIdRef = useRef<number | null>(existingChatSessionId);
// LLM
const llmOverrideManager = useLlmOverride(selectedChatSession);
const existingChatSessionPersonaId = selectedChatSession?.persona_id;
// Assistants
const filteredAssistants = orderAssistantsForUser(availableAssistants, user);
const existingChatSessionAssistantId = selectedChatSession?.persona_id;
const [selectedAssistant, setSelectedAssistant] = useState<
Persona | undefined
>(
// NOTE: look through available assistants here, so that even if the user
// has hidden this assistant it still shows the correct assistant when
// going back to an old chat session
existingChatSessionAssistantId !== undefined
? availableAssistants.find(
(assistant) => assistant.id === existingChatSessionAssistantId
)
: defaultSelectedAssistantId !== undefined
? availableAssistants.find(
(assistant) => assistant.id === defaultSelectedAssistantId
)
: undefined
);
const setSelectedAssistantFromId = (assistantId: number) => {
// NOTE: also intentionally look through available assistants here, so that
// even if the user has hidden an assistant they can still go back to it
// for old chats
setSelectedAssistant(
availableAssistants.find((assistant) => assistant.id === assistantId)
);
};
const liveAssistant =
selectedAssistant || filteredAssistants[0] || availableAssistants[0];
// this is for "@"ing assistants
const [alternativeAssistant, setAlternativeAssistant] =
useState<Persona | null>(null);
// this is used to track which assistant is being used to generate the current message
// for example, this would come into play when:
// 1. default assistant is `Danswer`
// 2. we "@"ed the `GPT` assistant and sent a message
// 3. while the `GPT` assistant message is generating, we "@" the `Paraphrase` assistant
const [alternativeGeneratingAssistant, setAlternativeGeneratingAssistant] =
useState<Persona | null>(null);
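(Read together with the comment above, the assistant used to generate a message is resolved with the precedence sketched below; this mirrors the currentAssistantId logic later in this diff rather than adding anything new.)
// Sketch of the precedence when a message is submitted:
// 1. an explicit per-call override passed into onSubmit ("@"ed at send time)
// 2. the currently "@"ed alternative assistant, if any
// 3. the live (selected or default) assistant
const generatingAssistant =
  alternativeAssistantOverride ?? alternativeAssistant ?? liveAssistant;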
// used to track whether or not the initial "submit on load" has been performed
// this only applies if `?submit-on-load=true` or `?submit-on-load=1` is in the URL
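(A hypothetical URL that would exercise this path; the parameter names appear elsewhere in this compare, the assistantId value is made up.)
// e.g. open the chat page with assistant 3 pre-selected and submit the seeded
// message immediately on load:
const exampleUrl = "/chat?assistantId=3&submit-on-load=true";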
@@ -182,14 +214,10 @@ export function ChatPage({
async function initialSessionFetch() {
if (existingChatSessionId === null) {
setIsFetchingChatMessages(false);
if (defaultSelectedPersonaId !== undefined) {
setSelectedPersona(
filteredAssistants.find(
(persona) => persona.id === defaultSelectedPersonaId
)
);
if (defaultSelectedAssistantId !== undefined) {
setSelectedAssistantFromId(defaultSelectedAssistantId);
} else {
setSelectedPersona(undefined);
setSelectedAssistant(undefined);
}
setCompleteMessageDetail({
sessionId: null,
@@ -214,12 +242,7 @@ export function ChatPage({
);
const chatSession = (await response.json()) as BackendChatSession;
setSelectedPersona(
filteredAssistants.find(
(persona) => persona.id === chatSession.persona_id
)
);
setSelectedAssistantFromId(chatSession.persona_id);
const newMessageMap = processRawChatHistory(chatSession.messages);
const newMessageHistory = buildLatestMessageChain(newMessageMap);
@@ -373,32 +396,18 @@ export function ChatPage({
)
: { aiMessage: null };
const [selectedPersona, setSelectedPersona] = useState<Persona | undefined>(
existingChatSessionPersonaId !== undefined
? filteredAssistants.find(
(persona) => persona.id === existingChatSessionPersonaId
)
: defaultSelectedPersonaId !== undefined
? filteredAssistants.find(
(persona) => persona.id === defaultSelectedPersonaId
)
: undefined
);
const livePersona =
selectedPersona || filteredAssistants[0] || availablePersonas[0];
const [chatSessionSharedStatus, setChatSessionSharedStatus] =
useState<ChatSessionSharedStatus>(ChatSessionSharedStatus.Private);
useEffect(() => {
if (messageHistory.length === 0 && chatSessionIdRef.current === null) {
setSelectedPersona(
setSelectedAssistant(
filteredAssistants.find(
(persona) => persona.id === defaultSelectedPersonaId
(persona) => persona.id === defaultSelectedAssistantId
)
);
}
}, [defaultSelectedPersonaId]);
}, [defaultSelectedAssistantId]);
const [
selectedDocuments,
@@ -414,7 +423,7 @@ export function ChatPage({
useEffect(() => {
async function fetchMaxTokens() {
const response = await fetch(
`/api/chat/max-selected-document-tokens?persona_id=${livePersona.id}`
`/api/chat/max-selected-document-tokens?persona_id=${liveAssistant.id}`
);
if (response.ok) {
const maxTokens = (await response.json()).max_tokens as number;
@@ -423,12 +432,12 @@ export function ChatPage({
}
fetchMaxTokens();
}, [livePersona]);
}, [liveAssistant]);
const filterManager = useFilters();
const [finalAvailableSources, finalAvailableDocumentSets] =
computeAvailableFilters({
selectedPersona,
selectedPersona: selectedAssistant,
availableSources,
availableDocumentSets,
});
@@ -624,16 +633,16 @@ export function ChatPage({
queryOverride,
forceSearch,
isSeededChat,
alternativeAssistant = null,
alternativeAssistantOverride = null,
}: {
messageIdToResend?: number;
messageOverride?: string;
queryOverride?: string;
forceSearch?: boolean;
isSeededChat?: boolean;
alternativeAssistant?: Persona | null;
alternativeAssistantOverride?: Persona | null;
} = {}) => {
setAlternativeGeneratingAssistant(alternativeAssistant);
setAlternativeGeneratingAssistant(alternativeAssistantOverride);
clientScrollToBottom();
let currChatSessionId: number;
@@ -643,7 +652,7 @@ export function ChatPage({
if (isNewSession) {
currChatSessionId = await createChatSession(
livePersona?.id || 0,
liveAssistant?.id || 0,
searchParamBasedChatSessionName
);
} else {
@@ -721,9 +730,9 @@ export function ChatPage({
parentMessage = frozenMessageMap.get(SYSTEM_MESSAGE_ID) || null;
}
const currentAssistantId = alternativeAssistant
? alternativeAssistant.id
: selectedAssistant?.id;
const currentAssistantId = alternativeAssistantOverride
? alternativeAssistantOverride.id
: alternativeAssistant?.id || liveAssistant.id;
resetInputBar();
@@ -751,7 +760,7 @@ export function ChatPage({
fileDescriptors: currentMessageFiles,
parentMessageId: lastSuccessfulMessageId,
chatSessionId: currChatSessionId,
promptId: livePersona?.prompts[0]?.id || 0,
promptId: liveAssistant?.prompts[0]?.id || 0,
filters: buildFilters(
filterManager.selectedSources,
filterManager.selectedDocumentSets,
@@ -868,7 +877,7 @@ export function ChatPage({
files: finalMessage?.files || aiMessageImages || [],
toolCalls: finalMessage?.tool_calls || toolCalls,
parentMessageId: newUserMessageId,
alternateAssistantID: selectedAssistant?.id,
alternateAssistantID: alternativeAssistant?.id,
},
]);
}
@@ -964,19 +973,23 @@ export function ChatPage({
}
};
const onPersonaChange = (persona: Persona | null) => {
if (persona && persona.id !== livePersona.id) {
const onAssistantChange = (assistant: Persona | null) => {
if (assistant && assistant.id !== liveAssistant.id) {
// remove uploaded files
setCurrentMessageFiles([]);
setSelectedPersona(persona);
setSelectedAssistant(assistant);
textAreaRef.current?.focus();
router.push(buildChatUrl(searchParams, null, persona.id));
router.push(buildChatUrl(searchParams, null, assistant.id));
}
};
const handleImageUpload = (acceptedFiles: File[]) => {
const llmAcceptsImages = checkLLMSupportsImageInput(
...getFinalLLM(llmProviders, livePersona, llmOverrideManager.llmOverride)
...getFinalLLM(
llmProviders,
liveAssistant,
llmOverrideManager.llmOverride
)
);
const imageFiles = acceptedFiles.filter((file) =>
file.type.startsWith("image/")
@@ -1058,23 +1071,23 @@ export function ChatPage({
useEffect(() => {
const includes = checkAnyAssistantHasSearch(
messageHistory,
availablePersonas,
livePersona
availableAssistants,
liveAssistant
);
setRetrievalEnabled(includes);
}, [messageHistory, availablePersonas, livePersona]);
}, [messageHistory, availableAssistants, liveAssistant]);
const [retrievalEnabled, setRetrievalEnabled] = useState(() => {
return checkAnyAssistantHasSearch(
messageHistory,
availablePersonas,
livePersona
availableAssistants,
liveAssistant
);
});
const innerSidebarElementRef = useRef<HTMLDivElement>(null);
const currentPersona = selectedAssistant || livePersona;
const currentPersona = alternativeAssistant || liveAssistant;
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
@@ -1176,21 +1189,8 @@ export function ChatPage({
/>
)}
<ConfigurationModal
chatSessionId={chatSessionIdRef.current!}
activeTab={configModalActiveTab}
setActiveTab={setConfigModalActiveTab}
onClose={() => setConfigModalActiveTab(null)}
filterManager={filterManager}
availableAssistants={filteredAssistants}
selectedAssistant={livePersona}
setSelectedAssistant={onPersonaChange}
llmProviders={llmProviders}
llmOverrideManager={llmOverrideManager}
/>
<div className="flex h-[calc(100dvh)] flex-col w-full">
{livePersona && (
{liveAssistant && (
<FunctionalHeader
page="chat"
setSharingModalVisible={
@@ -1203,6 +1203,23 @@ export function ChatPage({
currentChatSession={selectedChatSession}
/>
)}
<div className="w-full flex">
<div
style={{ transition: "width 0.30s ease-out" }}
className={`
flex-none
overflow-y-hidden
bg-background-100
transition-all
bg-opacity-80
duration-300
ease-in-out
h-full
${toggledSidebar || showDocSidebar ? "w-[300px]" : "w-[0px]"}
`}
/>
<ChatBanner />
</div>
{documentSidebarInitialWidth !== undefined ? (
<Dropzone onDrop={handleImageUpload} noClick>
{({ getRootProps }) => (
@@ -1231,15 +1248,14 @@ export function ChatPage({
ref={scrollableDivRef}
>
{/* ChatBanner is a custom banner that displays an admin-specified message at
the top of the chat page. Only used in the EE version of the app. */}
<ChatBanner />
the top of the chat page. Only used in the EE version of the app. */}
{messageHistory.length === 0 &&
!isFetchingChatMessages &&
!isStreaming && (
<ChatIntro
availableSources={finalAvailableSources}
selectedPersona={livePersona}
selectedPersona={liveAssistant}
/>
)}
<div
@@ -1319,7 +1335,7 @@ export function ChatPage({
const currentAlternativeAssistant =
message.alternateAssistantID != null
? availablePersonas.find(
? availableAssistants.find(
(persona) =>
persona.id ==
message.alternateAssistantID
@@ -1342,7 +1358,7 @@ export function ChatPage({
toggleDocumentSelectionAspects
}
docs={message.documents}
currentPersona={livePersona}
currentPersona={liveAssistant}
alternativeAssistant={
currentAlternativeAssistant
}
@@ -1352,7 +1368,7 @@ export function ChatPage({
query={
messageHistory[i]?.query || undefined
}
personaName={livePersona.name}
personaName={liveAssistant.name}
citedDocuments={getCitedDocumentsFromMessage(
message
)}
@@ -1404,7 +1420,7 @@ export function ChatPage({
messageIdToResend:
previousMessage.messageId,
queryOverride: newQuery,
alternativeAssistant:
alternativeAssistantOverride:
currentAlternativeAssistant,
});
}
@@ -1435,7 +1451,7 @@ export function ChatPage({
messageIdToResend:
previousMessage.messageId,
forceSearch: true,
alternativeAssistant:
alternativeAssistantOverride:
currentAlternativeAssistant,
});
} else {
@@ -1460,9 +1476,9 @@ export function ChatPage({
return (
<div key={messageReactComponentKey}>
<AIMessage
currentPersona={livePersona}
currentPersona={liveAssistant}
messageId={message.messageId}
personaName={livePersona.name}
personaName={liveAssistant.name}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
@@ -1481,13 +1497,13 @@ export function ChatPage({
key={`${messageHistory.length}-${chatSessionIdRef.current}`}
>
<AIMessage
currentPersona={livePersona}
currentPersona={liveAssistant}
alternativeAssistant={
alternativeGeneratingAssistant ??
selectedAssistant
alternativeAssistant
}
messageId={null}
personaName={livePersona.name}
personaName={liveAssistant.name}
content={
<div className="text-sm my-auto">
<ThreeDots
@@ -1513,7 +1529,7 @@ export function ChatPage({
{currentPersona &&
currentPersona.starter_messages &&
currentPersona.starter_messages.length > 0 &&
selectedPersona &&
selectedAssistant &&
messageHistory.length === 0 &&
!isFetchingChatMessages && (
<div
@@ -1570,32 +1586,25 @@ export function ChatPage({
<ChatInputBar
showDocs={() => setDocumentSelection(true)}
selectedDocuments={selectedDocuments}
setSelectedAssistant={onPersonaChange}
onSetSelectedAssistant={(
alternativeAssistant: Persona | null
) => {
setSelectedAssistant(alternativeAssistant);
}}
alternativeAssistant={selectedAssistant}
personas={filteredAssistants}
// assistant stuff
assistantOptions={filteredAssistants}
selectedAssistant={liveAssistant}
setSelectedAssistant={onAssistantChange}
setAlternativeAssistant={setAlternativeAssistant}
alternativeAssistant={alternativeAssistant}
// end assistant stuff
message={message}
setMessage={setMessage}
onSubmit={onSubmit}
isStreaming={isStreaming}
setIsCancelled={setIsCancelled}
retrievalDisabled={
!personaIncludesRetrieval(currentPersona)
}
filterManager={filterManager}
llmOverrideManager={llmOverrideManager}
selectedAssistant={livePersona}
files={currentMessageFiles}
setFiles={setCurrentMessageFiles}
handleFileUpload={handleImageUpload}
setConfigModalActiveTab={setConfigModalActiveTab}
textAreaRef={textAreaRef}
chatSessionId={chatSessionIdRef.current!}
availableAssistants={availablePersonas}
/>
</div>
</div>

View File

@@ -5,10 +5,10 @@ import { ChatPage } from "./ChatPage";
import FunctionalWrapper from "./shared_chat_search/FunctionalWrapper";
export default function WrappedChat({
defaultPersonaId,
defaultAssistantId,
initiallyToggled,
}: {
defaultPersonaId?: number;
defaultAssistantId?: number;
initiallyToggled: boolean;
}) {
return (
@@ -17,7 +17,7 @@ export default function WrappedChat({
content={(toggledSidebar, toggle) => (
<ChatPage
toggle={toggle}
defaultSelectedPersonaId={defaultPersonaId}
defaultSelectedAssistantId={defaultAssistantId}
toggledSidebar={toggledSidebar}
/>
)}

View File

@@ -21,7 +21,6 @@ import { IconType } from "react-icons";
import Popup from "../../../components/popup/Popup";
import { LlmTab } from "../modal/configuration/LlmTab";
import { AssistantsTab } from "../modal/configuration/AssistantsTab";
import ChatInputAssistant from "./ChatInputAssistant";
import { DanswerDocument } from "@/lib/search/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Tooltip } from "@/components/tooltip/Tooltip";
@@ -29,7 +28,6 @@ import { Hoverable } from "@/components/Hoverable";
const MAX_INPUT_HEIGHT = 200;
export function ChatInputBar({
personas,
showDocs,
selectedDocuments,
message,
@@ -37,34 +35,32 @@ export function ChatInputBar({
onSubmit,
isStreaming,
setIsCancelled,
retrievalDisabled,
filterManager,
llmOverrideManager,
onSetSelectedAssistant,
selectedAssistant,
files,
// assistants
selectedAssistant,
assistantOptions,
setSelectedAssistant,
setAlternativeAssistant,
files,
setFiles,
handleFileUpload,
setConfigModalActiveTab,
textAreaRef,
alternativeAssistant,
chatSessionId,
availableAssistants,
}: {
showDocs: () => void;
selectedDocuments: DanswerDocument[];
availableAssistants: Persona[];
onSetSelectedAssistant: (alternativeAssistant: Persona | null) => void;
assistantOptions: Persona[];
setAlternativeAssistant: (alternativeAssistant: Persona | null) => void;
setSelectedAssistant: (assistant: Persona) => void;
personas: Persona[];
message: string;
setMessage: (message: string) => void;
onSubmit: () => void;
isStreaming: boolean;
setIsCancelled: (value: boolean) => void;
retrievalDisabled: boolean;
filterManager: FilterManager;
llmOverrideManager: LlmOverrideManager;
selectedAssistant: Persona;
@@ -72,7 +68,6 @@ export function ChatInputBar({
files: FileDescriptor[];
setFiles: (files: FileDescriptor[]) => void;
handleFileUpload: (files: File[]) => void;
setConfigModalActiveTab: (tab: string) => void;
textAreaRef: React.RefObject<HTMLTextAreaElement>;
chatSessionId?: number;
}) {
@@ -136,8 +131,10 @@ export function ChatInputBar({
};
// Update selected persona
const updateCurrentPersona = (persona: Persona) => {
onSetSelectedAssistant(persona.id == selectedAssistant.id ? null : persona);
const updatedTaggedAssistant = (assistant: Persona) => {
setAlternativeAssistant(
assistant.id == selectedAssistant.id ? null : assistant
);
hideSuggestions();
setMessage("");
};
@@ -160,8 +157,8 @@ export function ChatInputBar({
}
};
const filteredPersonas = personas.filter((persona) =>
persona.name.toLowerCase().startsWith(
const assistantTagOptions = assistantOptions.filter((assistant) =>
assistant.name.toLowerCase().startsWith(
message
.slice(message.lastIndexOf("@") + 1)
.split(/\s/)[0]
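(A hedged sketch of how that token extraction behaves; the sample message is made up, and "GPT" is only borrowed from the assistant example earlier in this diff.)
// With message = "please summarize @gp", the text after the last "@" is "gp",
// so only assistants whose lower-cased name starts with "gp" (for instance a
// hypothetical "GPT" assistant) remain as tag suggestions.
const sampleMessage = "please summarize @gp";
const tagToken = sampleMessage
  .slice(sampleMessage.lastIndexOf("@") + 1)
  .split(/\s/)[0]; // "gp"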
@@ -174,18 +171,18 @@ export function ChatInputBar({
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (
showSuggestions &&
filteredPersonas.length > 0 &&
assistantTagOptions.length > 0 &&
(e.key === "Tab" || e.key == "Enter")
) {
e.preventDefault();
if (assistantIconIndex == filteredPersonas.length) {
if (assistantIconIndex == assistantTagOptions.length) {
window.open("/assistants/new", "_blank");
hideSuggestions();
setMessage("");
} else {
const option =
filteredPersonas[assistantIconIndex >= 0 ? assistantIconIndex : 0];
updateCurrentPersona(option);
assistantTagOptions[assistantIconIndex >= 0 ? assistantIconIndex : 0];
updatedTaggedAssistant(option);
}
}
if (!showSuggestions) {
@@ -195,7 +192,7 @@ export function ChatInputBar({
if (e.key === "ArrowDown") {
e.preventDefault();
setAssistantIconIndex((assistantIconIndex) =>
Math.min(assistantIconIndex + 1, filteredPersonas.length)
Math.min(assistantIconIndex + 1, assistantTagOptions.length)
);
} else if (e.key === "ArrowUp") {
e.preventDefault();
@@ -219,35 +216,36 @@ export function ChatInputBar({
mx-auto
"
>
{showSuggestions && filteredPersonas.length > 0 && (
{showSuggestions && assistantTagOptions.length > 0 && (
<div
ref={suggestionsRef}
className="text-sm absolute inset-x-0 top-0 w-full transform -translate-y-full"
>
<div className="rounded-lg py-1.5 bg-background border border-border-medium shadow-lg mx-2 px-1.5 mt-2 rounded z-10">
{filteredPersonas.map((currentPersona, index) => (
{assistantTagOptions.map((currentAssistant, index) => (
<button
key={index}
className={`px-2 ${
assistantIconIndex == index && "bg-hover-lightish"
} rounded rounded-lg content-start flex gap-x-1 py-2 w-full hover:bg-hover-lightish cursor-pointer`}
onClick={() => {
updateCurrentPersona(currentPersona);
updatedTaggedAssistant(currentAssistant);
}}
>
<p className="font-bold">{currentPersona.name}</p>
<p className="font-bold">{currentAssistant.name}</p>
<p className="line-clamp-1">
{currentPersona.id == selectedAssistant.id &&
{currentAssistant.id == selectedAssistant.id &&
"(default) "}
{currentPersona.description}
{currentAssistant.description}
</p>
</button>
))}
<a
key={filteredPersonas.length}
key={assistantTagOptions.length}
target="_blank"
className={`${
assistantIconIndex == filteredPersonas.length && "bg-hover"
assistantIconIndex == assistantTagOptions.length &&
"bg-hover"
} rounded rounded-lg px-3 flex gap-x-1 py-2 w-full items-center hover:bg-hover-lightish cursor-pointer"`}
href="/assistants/new"
>
@@ -301,7 +299,7 @@ export function ChatInputBar({
<Hoverable
icon={FiX}
onClick={() => onSetSelectedAssistant(null)}
onClick={() => setAlternativeAssistant(null)}
/>
</div>
</div>
@@ -409,7 +407,7 @@ export function ChatInputBar({
removePadding
content={(close) => (
<AssistantsTab
availableAssistants={availableAssistants}
availableAssistants={assistantOptions}
llmProviders={llmProviders}
selectedAssistant={selectedAssistant}
onSelect={(assistant) => {

View File

@@ -505,7 +505,7 @@ export function removeMessage(
export function checkAnyAssistantHasSearch(
messageHistory: Message[],
availablePersonas: Persona[],
availableAssistants: Persona[],
livePersona: Persona
): boolean {
const response =
@@ -516,8 +516,8 @@ export function checkAnyAssistantHasSearch(
) {
return false;
}
const alternateAssistant = availablePersonas.find(
(persona) => persona.id === message.alternateAssistantID
const alternateAssistant = availableAssistants.find(
(assistant) => assistant.id === message.alternateAssistantID
);
return alternateAssistant
? personaIncludesRetrieval(alternateAssistant)

View File

@@ -58,7 +58,6 @@ import { ValidSources } from "@/lib/types";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { useMouseTracking } from "./hooks";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { getTitleFromDocument } from "@/lib/sources";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@@ -445,7 +444,8 @@ export const AIMessage = ({
>
<Citation link={doc.link} index={ind + 1} />
<p className="shrink truncate ellipsis break-all ">
{getTitleFromDocument(doc)}
{doc.semantic_identifier ||
doc.document_id}
</p>
<div className="ml-auto flex-none">
{doc.is_internet ? (
@@ -798,7 +798,7 @@ export const HumanMessage = ({
!isEditing &&
(!files || files.length === 0)
) && "ml-auto"
} relative max-w-[70%] mb-auto rounded-3xl bg-user px-5 py-2.5`}
} relative max-w-[70%] mb-auto whitespace-break-spaces rounded-3xl bg-user px-5 py-2.5`}
>
{content}
</div>

View File

@@ -62,11 +62,10 @@ export function AssistantsTab({
toolName = "Image Generation";
toolIcon = <FiImage className="mr-1 my-auto" />;
}
return (
<Bubble key={tool.id} isSelected={false}>
<div className="flex flex-row gap-1">
{toolIcon}
<div className="flex line-wrap break-all flex-row gap-1">
<div className="flex-none my-auto">{toolIcon}</div>
{toolName}
</div>
</Bubble>

View File

@@ -1,180 +0,0 @@
"use client";
import React, { useEffect } from "react";
import { Modal } from "../../../../components/Modal";
import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
import { FiltersTab } from "./FiltersTab";
import { FiCpu, FiFilter, FiX } from "react-icons/fi";
import { IconType } from "react-icons";
import { FaBrain } from "react-icons/fa";
import { AssistantsTab } from "./AssistantsTab";
import { Persona } from "@/app/admin/assistants/interfaces";
import { LlmTab } from "./LlmTab";
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
import { AssistantsIcon, IconProps } from "@/components/icons/icons";
const TabButton = ({
label,
icon: Icon,
isActive,
onClick,
}: {
label: string;
icon: IconType;
isActive: boolean;
onClick: () => void;
}) => (
<button
onClick={onClick}
className={`
pb-4
pt-6
px-2
text-emphasis
font-bold
${isActive ? "border-b-2 border-accent" : ""}
hover:bg-hover-light
hover:text-strong
transition
duration-200
ease-in-out
flex
`}
>
<Icon className="inline-block mr-2 my-auto" size="16" />
<p className="my-auto">{label}</p>
</button>
);
export function ConfigurationModal({
activeTab,
setActiveTab,
onClose,
availableAssistants,
selectedAssistant,
setSelectedAssistant,
filterManager,
llmProviders,
llmOverrideManager,
chatSessionId,
}: {
activeTab: string | null;
setActiveTab: (tab: string | null) => void;
onClose: () => void;
availableAssistants: Persona[];
selectedAssistant: Persona;
setSelectedAssistant: (assistant: Persona) => void;
filterManager: FilterManager;
llmProviders: LLMProviderDescriptor[];
llmOverrideManager: LlmOverrideManager;
chatSessionId?: number;
}) {
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === "Escape") {
onClose();
}
};
document.addEventListener("keydown", handleKeyDown);
return () => {
document.removeEventListener("keydown", handleKeyDown);
};
}, [onClose]);
if (!activeTab) return null;
return (
<Modal
onOutsideClick={onClose}
noPadding
className="
w-4/6
h-4/6
flex
flex-col
"
>
<div className="rounded flex flex-col overflow-hidden">
<div className="mb-4">
<div className="flex border-b border-border bg-background-emphasis">
<div className="flex px-6 gap-x-2">
<TabButton
label="Assistants"
icon={FaBrain}
isActive={activeTab === "assistants"}
onClick={() => setActiveTab("assistants")}
/>
<TabButton
label="Models"
icon={FiCpu}
isActive={activeTab === "llms"}
onClick={() => setActiveTab("llms")}
/>
<TabButton
label="Filters"
icon={FiFilter}
isActive={activeTab === "filters"}
onClick={() => setActiveTab("filters")}
/>
</div>
<button
className="
ml-auto
px-1
py-1
text-xs
font-medium
rounded
hover:bg-hover
focus:outline-none
focus:ring-2
focus:ring-offset-2
focus:ring-subtle
flex
items-center
h-fit
my-auto
mr-5
"
onClick={onClose}
>
<FiX size={24} />
</button>
</div>
</div>
<div className="flex flex-col overflow-y-auto">
<div className="px-8 pt-4">
{activeTab === "filters" && (
<FiltersTab filterManager={filterManager} />
)}
{activeTab === "llms" && (
<LlmTab
chatSessionId={chatSessionId}
llmOverrideManager={llmOverrideManager}
currentAssistant={selectedAssistant}
/>
)}
{activeTab === "assistants" && (
<div>
<AssistantsTab
availableAssistants={availableAssistants}
llmProviders={llmProviders}
selectedAssistant={selectedAssistant}
onSelect={(assistant) => {
setSelectedAssistant(assistant);
onClose();
}}
/>
</div>
)}
</div>
</div>
</div>
</Modal>
);
}

View File

@@ -35,7 +35,7 @@ export default async function Page({
folders,
toggleSidebar,
openedFolders,
defaultPersonaId,
defaultAssistantId,
finalDocumentSidebarInitialWidth,
shouldShowWelcomeModal,
shouldDisplaySourcesIncompleteModal,
@@ -58,7 +58,7 @@ export default async function Page({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availablePersonas: assistants,
availableAssistants: assistants,
availableTags: tags,
llmProviders,
folders,
@@ -66,7 +66,7 @@ export default async function Page({
}}
>
<WrappedChat
defaultPersonaId={defaultPersonaId}
defaultAssistantId={defaultAssistantId}
initiallyToggled={toggleSidebar}
/>
</ChatProvider>

View File

@@ -1,111 +0,0 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import { BasicSelectable } from "@/components/BasicClickable";
import { AssistantsIcon } from "@/components/icons/icons";
import { User } from "@/lib/types";
import { Text } from "@tremor/react";
import Link from "next/link";
import { FaRobot } from "react-icons/fa";
import { FiEdit2 } from "react-icons/fi";
function AssistantDisplay({
persona,
onSelect,
user,
}: {
persona: Persona;
onSelect: (persona: Persona) => void;
user: User | null;
}) {
const isEditable =
(!user || user.id === persona.owner?.id) &&
!persona.default_persona &&
(!persona.is_public || !user || user.role === "admin");
return (
<div className="flex">
<div className="w-full" onClick={() => onSelect(persona)}>
<BasicSelectable selected={false} fullWidth>
<div className="flex">
<div className="truncate w-48 3xl:w-56 flex">
<AssistantsIcon className="mr-2 my-auto" size={16} />{" "}
{persona.name}
</div>
</div>
</BasicSelectable>
</div>
{isEditable && (
<div className="pl-2 my-auto">
<Link href={`/assistants/edit/${persona.id}`}>
<FiEdit2
className="my-auto ml-auto hover:bg-hover p-0.5"
size={20}
/>
</Link>
</div>
)}
</div>
);
}
export function AssistantsTab({
personas,
onPersonaChange,
user,
}: {
personas: Persona[];
onPersonaChange: (persona: Persona | null) => void;
user: User | null;
}) {
const globalAssistants = personas.filter((persona) => persona.is_public);
const personalAssistants = personas.filter(
(persona) =>
(!user || persona.users.some((u) => u.id === user.id)) &&
!persona.is_public
);
return (
<div className="mt-4 pb-1 overflow-y-auto h-full flex flex-col gap-y-1">
<Text className="mx-3 text-xs mb-4">
Select an Assistant below to begin a new chat with them!
</Text>
<div className="mx-3">
{globalAssistants.length > 0 && (
<>
<div className="text-xs text-subtle flex pb-0.5 ml-1 mb-1.5 font-bold">
Global
</div>
{globalAssistants.map((persona) => {
return (
<AssistantDisplay
key={persona.id}
persona={persona}
onSelect={onPersonaChange}
user={user}
/>
);
})}
</>
)}
{personalAssistants.length > 0 && (
<>
<div className="text-xs text-subtle flex pb-0.5 ml-1 mb-1.5 mt-5 font-bold">
Personal
</div>
{personalAssistants.map((persona) => {
return (
<AssistantDisplay
key={persona.id}
persona={persona}
onSelect={onPersonaChange}
user={user}
/>
);
})}
</>
)}
</div>
</div>
);
}

View File

@@ -99,22 +99,28 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
h-screen
transition-transform`}
>
<div className="ml-4 mr-3 flex flex gap-x-1 items-center mt-2 my-auto text-text-700 text-xl">
<div className="mr-1 my-auto h-6 w-6">
<div className="max-w-full ml-3 mr-3 mt-2 flex flex gap-x-1 items-center my-auto text-text-700 text-xl">
<div className="mr-1 mb-auto h-6 w-6">
<Logo height={24} width={24} />
</div>
<div className="invisible">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
<div>
<HeaderTitle>
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Danswer</p>
)}
</div>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
</div>
{toggleSidebar && (
<Tooltip delayDuration={1000} content={`${commandSymbol}E show`}>
<button className="ml-auto" onClick={toggleSidebar}>
<button className="mb-auto ml-auto" onClick={toggleSidebar}>
{!toggled ? <RightToLineIcon /> : <LefToLineIcon />}
</button>
</Tooltip>

View File

@@ -3,6 +3,7 @@
import { HeaderTitle } from "@/components/header/Header";
import { Logo } from "@/components/Logo";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
import { useContext } from "react";
export default function FixedLogo() {
@@ -11,17 +12,24 @@ export default function FixedLogo() {
const enterpriseSettings = combinedSettings?.enterpriseSettings;
return (
<div className="fixed flex z-40 left-4 top-2">
{" "}
<a href="/chat" className="ml-7 text-text-700 text-xl">
<div>
<div className="absolute flex z-40 left-2.5 top-2">
<div className="max-w-[200px] flex gap-x-1 my-auto">
<div className="flex-none invisible mb-auto">
<Logo />
</div>
<div className="">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
<div>
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Danswer</p>
)}
</div>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
</div>
</a>
</div>
</div>
);
}

View File

@@ -10,12 +10,11 @@ const ToggleSwitch = () => {
const commandSymbol = KeyboardSymbol();
const pathname = usePathname();
const router = useRouter();
const [activeTab, setActiveTab] = useState(() => {
if (typeof window !== "undefined") {
return localStorage.getItem("activeTab") || "chat";
}
return "chat";
return pathname == "/search" ? "search" : "chat";
});
const [isInitialLoad, setIsInitialLoad] = useState(true);
useEffect(() => {

View File

@@ -168,7 +168,8 @@ export default async function Home() {
(ccPair) => ccPair.has_successful_run && ccPair.docs_indexed > 0
) &&
!shouldDisplayNoSourcesModal &&
!shouldShowWelcomeModal;
!shouldShowWelcomeModal &&
(!user || user.role == "admin");
const sidebarToggled = cookies().get(SIDEBAR_TOGGLED_COOKIE_NAME);
const agenticSearchToggle = cookies().get(AGENTIC_SEARCH_TYPE_COOKIE_NAME);
@@ -178,7 +179,7 @@ export default async function Home() {
: false;
const agenticSearchEnabled = agenticSearchToggle
? agenticSearchToggle.value.toLocaleLowerCase() == "true" || true
? agenticSearchToggle.value.toLocaleLowerCase() == "true" || false
: false;
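(For clarity on this small change, a sketch of the before and after behavior; cookieValue stands in for agenticSearchToggle.value.)
// Before: cookieValue.toLocaleLowerCase() == "true" || true  -> always true
// After:  cookieValue.toLocaleLowerCase() == "true" || false -> true only when
//         the cookie is literally "true" (case-insensitive), otherwise false.
const agenticSearchFromCookie = (cookieValue: string) =>
  cookieValue.toLocaleLowerCase() == "true" || false;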
return (

View File

@@ -0,0 +1,26 @@
import React from "react";
import { Button } from "@tremor/react";
import { FiChevronDown, FiChevronRight } from "react-icons/fi";
interface AdvancedOptionsToggleProps {
showAdvancedOptions: boolean;
setShowAdvancedOptions: (show: boolean) => void;
}
export function AdvancedOptionsToggle({
showAdvancedOptions,
setShowAdvancedOptions,
}: AdvancedOptionsToggleProps) {
return (
<Button
type="button"
variant="light"
size="xs"
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
className="mb-4 text-xs text-text-500 hover:text-text-400"
>
Advanced Options
</Button>
);
}
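(Usage of this new toggle, as it already appears in the LLM provider forms earlier in this compare; the surrounding state hook matches those forms.)
// In the provider update forms:
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
// ...
<AdvancedOptionsToggle
  showAdvancedOptions={showAdvancedOptions}
  setShowAdvancedOptions={setShowAdvancedOptions}
/>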

View File

@@ -31,7 +31,10 @@ export function Logo({
}
return (
<div style={{ height, width }} className={`relative ${className}`}>
<div
style={{ height, width }}
className={`flex-none relative ${className}`}
>
{/* TODO: figure out how to use Next Image here */}
<img
src="/api/enterprise-settings/logo"

View File

@@ -91,7 +91,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
return (
<div className="h-screen overflow-y-hidden">
<div className="flex h-full">
<div className="w-64 z-20 bg-background-100 pt-4 pb-8 h-full border-r border-border miniscroll overflow-auto">
<div className="w-64 z-20 bg-background-100 pt-3 pb-8 h-full border-r border-border miniscroll overflow-auto">
<AdminSidebar
collections={[
{

View File

@@ -28,9 +28,9 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
return (
<aside className="pl-0">
<nav className="space-y-2 pl-4">
<div className="pb-12 flex">
<div className="fixed left-0 top-0 py-2 pl-4 bg-background-100 w-[200px]">
<nav className="space-y-2 pl-2">
<div className="mb-4 flex">
<div className="bg-background-100">
<Link
className="flex flex-col"
href={
@@ -39,8 +39,8 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
: "/search"
}
>
<div className="flex gap-x-1 my-auto">
<div className="my-auto">
<div className="max-w-[200px] flex gap-x-1 my-auto">
<div className="flex-none mb-auto">
<Logo />
</div>
<div className="my-auto">
@@ -50,7 +50,7 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle -mt-1.5">
<p className="text-xs text-subtle">
Powered by Danswer
</p>
)}
@@ -63,14 +63,16 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
</Link>
</div>
</div>
<Link href={"/chat"}>
<button className="text-sm block w-48 py-2.5 flex px-2 text-left bg-background-200 hover:bg-background-200/80 cursor-pointer rounded">
<BackIcon size={20} className="text-neutral" />
<p className="ml-1">Back to Danswer</p>
</button>
</Link>
<div className="px-3">
<Link href={"/chat"}>
<button className="text-sm block w-48 py-2.5 flex px-2 text-left bg-background-200 hover:bg-background-200/80 cursor-pointer rounded">
<BackIcon size={20} className="text-neutral" />
<p className="ml-1">Back to Danswer</p>
</button>
</Link>
</div>
{collections.map((collection, collectionInd) => (
<div key={collectionInd}>
<div className="px-3" key={collectionInd}>
<h2 className="text-xs text-strong font-bold pb-2">
<div>{collection.name}</div>
</h2>

View File

@@ -59,18 +59,23 @@ export default function FunctionalHeader({
<div className="pb-6 left-0 sticky top-0 z-10 w-full relative flex">
<div className="mt-2 mx-4 text-text-700 flex w-full">
<div className="absolute z-[100] my-auto flex items-center text-xl font-bold">
<FiSidebar size={20} />
<div className="ml-2 text-text-700 text-xl">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
<div className="pt-[2px] mb-auto">
<FiSidebar size={20} />
</div>
<div className="break-words inline-block w-fit ml-2 text-text-700 text-xl">
<div className="max-w-[200px]">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
</div>
</div>
{page == "chat" && (
<Tooltip delayDuration={1000} content={`${commandSymbol}U`}>
<Link
className="mb-auto pt-[2px]"
href={
`/${page}` +
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA &&
@@ -79,10 +84,9 @@ export default function FunctionalHeader({
: "")
}
>
<NewChatIcon
size={20}
className="ml-2 my-auto cursor-pointer text-text-700 hover:text-text-600 transition-colors duration-300"
/>
<div className="cursor-pointer ml-2 flex-none text-text-700 hover:text-text-600 transition-colors duration-300">
<NewChatIcon size={20} className="" />
</div>
</Link>
</Tooltip>
)}

View File

@@ -18,7 +18,7 @@ interface ChatContextProps {
chatSessions: ChatSession[];
availableSources: ValidSources[];
availableDocumentSets: DocumentSet[];
availablePersonas: Persona[];
availableAssistants: Persona[];
availableTags: Tag[];
llmProviders: LLMProviderDescriptor[];
folders: Folder[];

View File

@@ -11,7 +11,11 @@ import { Logo } from "../Logo";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
export function HeaderTitle({ children }: { children: JSX.Element | string }) {
return <h1 className="flex text-2xl text-strong font-bold">{children}</h1>;
return (
<h1 className="flex text-2xl text-strong leading-none font-bold">
{children}
</h1>
);
}
interface HeaderProps {
@@ -36,8 +40,8 @@ export function Header({ user, page }: HeaderProps) {
settings && settings.default_page === "chat" ? "/chat" : "/search"
}
>
<div className="flex my-auto">
<div className="mr-1 my-auto">
<div className="max-w-[200px] bg-black flex my-auto">
<div className="mr-1 mb-auto">
<Logo />
</div>
<div className="my-auto">
@@ -47,9 +51,7 @@ export function Header({ user, page }: HeaderProps) {
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle -mt-1.5">
Powered by Danswer
</p>
<p className="text-xs text-subtle">Powered by Danswer</p>
)}
</div>
) : (

View File

@@ -19,7 +19,7 @@ export function Citation({
return (
<CustomTooltip
citation
content={<p className="inline-block p-0 m-0 truncate">{link}</p>}
content={<div className="inline-block p-0 m-0 truncate">{link}</div>}
>
<a
onClick={() => (link ? window.open(link, "_blank") : undefined)}

View File

@@ -30,7 +30,6 @@ import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS";
interface FetchChatDataResult {
user: User | null;
chatSessions: ChatSession[];
ccPairs: CCPairBasicInfo[];
availableSources: ValidSources[];
documentSets: DocumentSet[];
@@ -39,7 +38,7 @@ interface FetchChatDataResult {
llmProviders: LLMProviderDescriptor[];
folders: Folder[];
openedFolders: Record<string, boolean>;
defaultPersonaId?: number;
defaultAssistantId?: number;
toggleSidebar: boolean;
finalDocumentSidebarInitialWidth?: number;
shouldShowWelcomeModal: boolean;
@@ -150,9 +149,9 @@ export async function fetchChatData(searchParams: {
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
}
const defaultPersonaIdRaw = searchParams["assistantId"];
const defaultPersonaId = defaultPersonaIdRaw
? parseInt(defaultPersonaIdRaw)
const defaultAssistantIdRaw = searchParams["assistantId"];
const defaultAssistantId = defaultAssistantIdRaw
? parseInt(defaultAssistantIdRaw)
: undefined;
const documentSidebarCookieInitialWidth = cookies().get(
@@ -178,7 +177,8 @@ export async function fetchChatData(searchParams: {
!shouldShowWelcomeModal &&
!ccPairs.some(
(ccPair) => ccPair.has_successful_run && ccPair.docs_indexed > 0
);
) &&
(!user || user.role == "admin");
// if no connectors are setup, only show personas that are pure
// passthrough and don't do any retrieval
@@ -209,7 +209,7 @@ export async function fetchChatData(searchParams: {
llmProviders,
folders,
openedFolders,
defaultPersonaId,
defaultAssistantId,
finalDocumentSidebarInitialWidth,
toggleSidebar,
shouldShowWelcomeModal,

View File

@@ -301,9 +301,3 @@ function stripTrailingSlash(str: string) {
}
return str;
}
export const getTitleFromDocument = (document: DanswerDocument) => {
return stripTrailingSlash(document.document_id).split("/")[
stripTrailingSlash(document.document_id).split("/").length - 1
];
};
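(The call site in AIMessage above now inlines the fallback directly; a minimal sketch of the replacement expression, where doc is a DanswerDocument.)
// Prefer the human-readable identifier, fall back to the raw document id.
const citationTitle = doc.semantic_identifier || doc.document_id;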