Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 16:25:45 +00:00)

Compare commits: eval/split...0.4.5 (27 commits)
| SHA1 |
|---|
| 48a0d29a5c |
| 6ff8e6c0ea |
| 2470c68506 |
| 866bc803b1 |
| 9c6084bd0d |
| a0b46c60c6 |
| 4029233df0 |
| 6c88c0156c |
| 33332d08f2 |
| 17005fb705 |
| 48a7fe80b1 |
| 1276732409 |
| f91b92a898 |
| 6222f533be |
| 1b49d17239 |
| 2f5f19642e |
| 6db4634871 |
| 5cfed45cef |
| 581ffde35a |
| 6313e6d91d |
| c09c94bf32 |
| 0e8ba111c8 |
| 2ba24b1734 |
| 44820b4909 |
| eb3e7610fc |
| 7fbbb174bb |
| 3854ca11af |
@@ -0,0 +1,41 @@
"""add_llm_group_permissions_control

Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558

"""
from alembic import op
import sqlalchemy as sa


revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "llm_provider__user_group",
        sa.Column("llm_provider_id", sa.Integer(), nullable=False),
        sa.Column("user_group_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["llm_provider_id"],
            ["llm_provider.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_group_id"],
            ["user_group.id"],
        ),
        sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
    )
    op.add_column(
        "llm_provider",
        sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
    )


def downgrade() -> None:
    op.drop_table("llm_provider__user_group")
    op.drop_column("llm_provider", "is_public")
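A minimal sketch of applying this revision programmatically (assumes a standard Alembic setup; the alembic.ini path is illustrative and not taken from this diff):

```python
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")           # path assumed, adjust to the backend's config
command.upgrade(alembic_cfg, "795b20b85b4b")  # apply up to the revision added above
```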
@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
)
@@ -80,7 +80,7 @@ def should_prune_cc_pair(
            return True
        return False

    if PREVENT_SIMULTANEOUS_PRUNING:
    if not ALLOW_SIMULTANEOUS_PRUNING:
        pruning_type_task_name = name_cc_prune_task()
        last_pruning_type_task = get_latest_task_by_type(
            pruning_type_task_name, db_session
@@ -89,11 +89,9 @@ def should_prune_cc_pair(
        if last_pruning_type_task and check_task_is_live_and_not_timed_out(
            last_pruning_type_task, db_session
        ):
            logger.info("Another Connector is already pruning. Skipping.")
            return False

    if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
        logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
        return False

    if not last_pruning_task.start_time:
@@ -33,7 +33,7 @@ from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -51,7 +51,7 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
@@ -187,37 +187,46 @@ def _handle_internet_search_tool_response_summary(
    )


def _check_should_force_search(
    new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
    # If files are already provided, don't run the search tool
def _get_force_search_settings(
    new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool:
    internet_search_available = any(
        isinstance(tool, InternetSearchTool) for tool in tools
    )
    search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)

    if not internet_search_available and not search_tool_available:
        # Does not matter much which tool is set here as force is false and neither tool is available
        return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)

    tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
    # Currently, the internet search tool does not support query override
    args = (
        {"query": new_msg_req.query_override}
        if new_msg_req.query_override and tool_name == SearchTool._NAME
        else None
    )

    if new_msg_req.file_descriptors:
        return None
        # If user has uploaded files they're using, don't run any of the search tools
        return ForceUseTool(force_use=False, tool_name=tool_name)

    if (
        new_msg_req.query_override
        or (
    should_force_search = any(
        [
            new_msg_req.retrieval_options
            and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
        )
        or new_msg_req.search_doc_ids
        or DISABLE_LLM_CHOOSE_SEARCH
    ):
        args = (
            {"query": new_msg_req.query_override}
            if new_msg_req.query_override
            else None
        )
        # if we are using selected docs, just put something here so the Tool doesn't need
        # to build its own args via an LLM call
        if new_msg_req.search_doc_ids:
            args = {"query": new_msg_req.message}
            and new_msg_req.retrieval_options.run_search
            == OptionalSearchSetting.ALWAYS,
            new_msg_req.search_doc_ids,
            DISABLE_LLM_CHOOSE_SEARCH,
        ]
    )

        return ForceUseTool(
            tool_name=SearchTool._NAME,
            args=args,
        )
    return None
    if should_force_search:
        # If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
        args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
        return ForceUseTool(force_use=True, tool_name=tool_name, args=args)

    return ForceUseTool(force_use=False, tool_name=tool_name, args=args)


ChatPacket = (
@@ -253,7 +262,6 @@ def stream_chat_message_objects(
    2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
    3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
    4. [always] Details on the final AI response message that is created

    """
    try:
        user_id = user.id if user is not None else None
@@ -361,6 +369,14 @@ def stream_chat_message_objects(
                "when the last message is not a user message."
            )

        # Disable Query Rephrasing for the first message
        # This leads to a better first response since the LLM rephrasing the question
        # leads to worst search quality
        if not history_msgs:
            new_msg_req.query_override = (
                new_msg_req.query_override or new_msg_req.message
            )

        # load all files needed for this chat chain in memory
        files = load_all_chat_files(
            history_msgs, new_msg_req.file_descriptors, db_session
@@ -576,11 +592,7 @@ def stream_chat_message_objects(
                PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
            ],
            tools=tools,
            force_use_tool=(
                _check_should_force_search(new_msg_req)
                if search_tool and len(tools) == 1
                else None
            ),
            force_use_tool=_get_force_search_settings(new_msg_req, tools),
        )

        reference_db_search_docs = None
@@ -214,8 +214,8 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (

DEFAULT_PRUNING_FREQ = 60 * 60 * 24  # Once a day

PREVENT_SIMULTANEOUS_PRUNING = (
    os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
ALLOW_SIMULTANEOUS_PRUNING = (
    os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)

# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
@@ -56,6 +56,16 @@ def extract_text_from_content(content: dict) -> str:
    return " ".join(texts)


def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
    if hasattr(jira_issue.fields, field):
        return getattr(jira_issue.fields, field)

    try:
        return jira_issue.raw["fields"][field]
    except Exception:
        return None


def _get_comment_strs(
    jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
@@ -117,8 +127,10 @@ def fetch_jira_issues_batch(
            continue

        comments = _get_comment_strs(jira, comment_email_blacklist)
        semantic_rep = f"{jira.fields.description}\n" + "\n".join(
            [f"Comment: {comment}" for comment in comments]
        semantic_rep = (
            f"{jira.fields.description}\n"
            if jira.fields.description
            else "" + "\n".join([f"Comment: {comment}" for comment in comments])
        )

        page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -147,14 +159,18 @@ def fetch_jira_issues_batch(
            pass

        metadata_dict = {}
        if jira.fields.priority:
            metadata_dict["priority"] = jira.fields.priority.name
        if jira.fields.status:
            metadata_dict["status"] = jira.fields.status.name
        if jira.fields.resolution:
            metadata_dict["resolution"] = jira.fields.resolution.name
        if jira.fields.labels:
            metadata_dict["label"] = jira.fields.labels
        priority = best_effort_get_field_from_issue(jira, "priority")
        if priority:
            metadata_dict["priority"] = priority.name
        status = best_effort_get_field_from_issue(jira, "status")
        if status:
            metadata_dict["status"] = status.name
        resolution = best_effort_get_field_from_issue(jira, "resolution")
        if resolution:
            metadata_dict["resolution"] = resolution.name
        labels = best_effort_get_field_from_issue(jira, "labels")
        if labels:
            metadata_dict["label"] = labels

        doc_batch.append(
            Document(
@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
        self.permissions: DiscoursePerms | None = None
        self.active_categories: set | None = None

    @rate_limit_builder(max_calls=100, period=60)
    @rate_limit_builder(max_calls=50, period=60)
    def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
        if not self.permissions:
            raise ConnectorMissingCredentialError("Discourse")
@@ -123,6 +123,8 @@ class DocumentBase(BaseModel):
        for char in replace_chars:
            title = title.replace(char, " ")
        title = title.strip()
        # Title could be quite long here as there is no truncation done
        # just prior to embedding, it could be truncated
        return title

    def get_metadata_str_attributes(self) -> list[str] | None:
@@ -50,9 +50,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
@@ -15,7 +15,7 @@ from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
    CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
@@ -1,15 +1,41 @@
from sqlalchemy import delete
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest


def update_group_llm_provider_relationships__no_commit(
    llm_provider_id: int,
    group_ids: list[int] | None,
    db_session: Session,
) -> None:
    # Delete existing relationships
    db_session.query(LLMProvider__UserGroup).filter(
        LLMProvider__UserGroup.llm_provider_id == llm_provider_id
    ).delete(synchronize_session="fetch")

    # Add new relationships from given group_ids
    if group_ids:
        new_relationships = [
            LLMProvider__UserGroup(
                llm_provider_id=llm_provider_id,
                user_group_id=group_id,
            )
            for group_id in group_ids
        ]
        db_session.add_all(new_relationships)


def upsert_cloud_embedding_provider(
    db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
@@ -36,36 +62,35 @@ def upsert_llm_provider(
    existing_llm_provider = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
    )
    if existing_llm_provider:
        existing_llm_provider.provider = llm_provider.provider
        existing_llm_provider.api_key = llm_provider.api_key
        existing_llm_provider.api_base = llm_provider.api_base
        existing_llm_provider.api_version = llm_provider.api_version
        existing_llm_provider.custom_config = llm_provider.custom_config
        existing_llm_provider.default_model_name = llm_provider.default_model_name
        existing_llm_provider.fast_default_model_name = (
            llm_provider.fast_default_model_name
        )
        existing_llm_provider.model_names = llm_provider.model_names
        db_session.commit()
        return FullLLMProvider.from_model(existing_llm_provider)
    # if it does not exist, create a new entry
    llm_provider_model = LLMProviderModel(
        name=llm_provider.name,
        provider=llm_provider.provider,
        api_key=llm_provider.api_key,
        api_base=llm_provider.api_base,
        api_version=llm_provider.api_version,
        custom_config=llm_provider.custom_config,
        default_model_name=llm_provider.default_model_name,
        fast_default_model_name=llm_provider.fast_default_model_name,
        model_names=llm_provider.model_names,
        is_default_provider=None,

    if not existing_llm_provider:
        existing_llm_provider = LLMProviderModel(name=llm_provider.name)
        db_session.add(existing_llm_provider)

    existing_llm_provider.provider = llm_provider.provider
    existing_llm_provider.api_key = llm_provider.api_key
    existing_llm_provider.api_base = llm_provider.api_base
    existing_llm_provider.api_version = llm_provider.api_version
    existing_llm_provider.custom_config = llm_provider.custom_config
    existing_llm_provider.default_model_name = llm_provider.default_model_name
    existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
    existing_llm_provider.model_names = llm_provider.model_names
    existing_llm_provider.is_public = llm_provider.is_public

    if not existing_llm_provider.id:
        # If its not already in the db, we need to generate an ID by flushing
        db_session.flush()

    # Make sure the relationship table stays up to date
    update_group_llm_provider_relationships__no_commit(
        llm_provider_id=existing_llm_provider.id,
        group_ids=llm_provider.groups,
        db_session=db_session,
    )
    db_session.add(llm_provider_model)

    db_session.commit()

    return FullLLMProvider.from_model(llm_provider_model)
    return FullLLMProvider.from_model(existing_llm_provider)


def fetch_existing_embedding_providers(
@@ -74,8 +99,29 @@ def fetch_existing_embedding_providers(
    return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())


def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
    return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_existing_llm_providers(
    db_session: Session,
    user: User | None = None,
) -> list[LLMProviderModel]:
    if not user:
        return list(db_session.scalars(select(LLMProviderModel)).all())
    stmt = select(LLMProviderModel).distinct()
    user_groups_subquery = (
        select(User__UserGroup.user_group_id)
        .where(User__UserGroup.user_id == user.id)
        .subquery()
    )
    access_conditions = or_(
        LLMProviderModel.is_public,
        LLMProviderModel.id.in_(  # User is part of a group that has access
            select(LLMProvider__UserGroup.llm_provider_id).where(
                LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery)  # type: ignore
            )
        ),
    )
    stmt = stmt.where(access_conditions)

    return list(db_session.scalars(stmt).all())


def fetch_embedding_provider(
@@ -119,6 +165,13 @@ def remove_embedding_provider(


def remove_llm_provider(db_session: Session, provider_id: int) -> None:
    # Remove LLMProvider's dependent relationships
    db_session.execute(
        delete(LLMProvider__UserGroup).where(
            LLMProvider__UserGroup.llm_provider_id == provider_id
        )
    )
    # Remove LLMProvider
    db_session.execute(
        delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
    )
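A hedged usage sketch of the new access filtering above: passing a user restricts the result to public providers plus those shared with the user's groups, while omitting the user (e.g. admin flows) still returns everything. The `db_session` and `user` objects are assumed to come from the usual FastAPI dependencies, and the import path is assumed to be the module these functions live in:

```python
from danswer.db.llm import fetch_existing_llm_providers  # module path assumed

# Providers visible to this user: public ones + ones shared via the user's groups
visible_providers = fetch_existing_llm_providers(db_session=db_session, user=user)

# No user given -> unrestricted listing (matches the previous behaviour)
all_providers = fetch_existing_llm_providers(db_session=db_session)
```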
@@ -932,6 +932,13 @@ class LLMProvider(Base):

    # should only be set for a single provider
    is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
    # EE only
    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    groups: Mapped[list["UserGroup"]] = relationship(
        "UserGroup",
        secondary="llm_provider__user_group",
        viewonly=True,
    )


class CloudEmbeddingProvider(Base):
@@ -1109,7 +1116,6 @@ class Persona(Base):
    # where lower value IDs (e.g. created earlier) are displayed first
    display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
    deleted: Mapped[bool] = mapped_column(Boolean, default=False)
    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

    # These are only defaults, users can select from all if desired
    prompts: Mapped[list[Prompt]] = relationship(
@@ -1137,6 +1143,7 @@ class Persona(Base):
        viewonly=True,
    )
    # EE only
    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    groups: Mapped[list["UserGroup"]] = relationship(
        "UserGroup",
        secondary="persona__user_group",
@@ -1360,6 +1367,17 @@ class Persona__UserGroup(Base):
    )


class LLMProvider__UserGroup(Base):
    __tablename__ = "llm_provider__user_group"

    llm_provider_id: Mapped[int] = mapped_column(
        ForeignKey("llm_provider.id"), primary_key=True
    )
    user_group_id: Mapped[int] = mapped_column(
        ForeignKey("user_group.id"), primary_key=True
    )


class DocumentSet__UserGroup(Base):
    __tablename__ = "document_set__user_group"
@@ -331,12 +331,18 @@ def _index_vespa_chunk(
    document = chunk.source_document
    # No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
    vespa_chunk_id = str(get_uuid_from_chunk(chunk))

    embeddings = chunk.embeddings

    if chunk.embeddings.full_embedding is None:
        embeddings.full_embedding = chunk.title_embedding
    embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}

    if embeddings.mini_chunk_embeddings:
        for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
            embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
            if m_c_embed is None:
                embeddings_name_vector_map[f"mini_chunk_{ind}"] = chunk.title_embedding
            else:
                embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed

    title = document.get_title_for_document_index()
@@ -103,6 +103,7 @@ def port_api_key_to_postgres() -> None:
            default_model_name=default_model_name,
            fast_default_model_name=default_fast_model_name,
            model_names=None,
            is_public=True,
        )
        llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
        update_default_provider(db_session, llm_provider.id)
@@ -15,7 +15,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
)
from danswer.connectors.models import Document
from danswer.indexing.models import DocAwareChunk
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import shared_precompare_cleanup
@@ -4,7 +4,6 @@ from abc import abstractmethod
from sqlalchemy.orm import Session

from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -14,8 +13,7 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.utils.batching import batch_list
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
@@ -66,21 +64,21 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            # The below are globally set, this flow always uses the indexing one
            server_host=INDEXING_MODEL_SERVER_HOST,
            server_port=INDEXING_MODEL_SERVER_PORT,
            retrim_content=True,
        )

    def embed_chunks(
        self,
        chunks: list[DocAwareChunk],
        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
        enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
    ) -> list[IndexChunk]:
        # Cache the Title embeddings to only have to do it once
        title_embed_dict: dict[str, list[float]] = {}
        title_embed_dict: dict[str, list[float] | None] = {}
        embedded_chunks: list[IndexChunk] = []

        # Create Mini Chunks for more precise matching of details
        # Off by default with unedited settings
        chunk_texts = []
        chunk_texts: list[str] = []
        chunk_mini_chunks_count = {}
        for chunk_ind, chunk in enumerate(chunks):
            chunk_texts.append(chunk.content)
@@ -92,22 +90,9 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            chunk_texts.extend(mini_chunk_texts)
            chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)

        # Batching for embedding
        text_batches = batch_list(chunk_texts, batch_size)

        embeddings: list[list[float]] = []
        len_text_batches = len(text_batches)
        for idx, text_batch in enumerate(text_batches, start=1):
            logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
            # Normalize embeddings is only configured via model_configs.py, be sure to use right
            # value for the set loss
            embeddings.extend(
                self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
            )

            # Replace line above with the line below for easy debugging of indexing flow
            # skipping the actual model
            # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
        embeddings = self.embedding_model.encode(
            chunk_texts, text_type=EmbedTextType.PASSAGE
        )

        chunk_titles = {
            chunk.source_document.get_title_for_document_index() for chunk in chunks
@@ -116,16 +101,15 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
        # Drop any None or empty strings
        chunk_titles_list = [title for title in chunk_titles if title]

        # Embed Titles in batches
        title_batches = batch_list(chunk_titles_list, batch_size)
        len_title_batches = len(title_batches)
        for ind_batch, title_batch in enumerate(title_batches, start=1):
            logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
        if chunk_titles_list:
            title_embeddings = self.embedding_model.encode(
                title_batch, text_type=EmbedTextType.PASSAGE
                chunk_titles_list, text_type=EmbedTextType.PASSAGE
            )
            title_embed_dict.update(
                {title: vector for title, vector in zip(title_batch, title_embeddings)}
                {
                    title: vector
                    for title, vector in zip(chunk_titles_list, title_embeddings)
                }
            )

        # Mapping embeddings to chunks
@@ -184,4 +168,6 @@ def get_embedding_model_from_db_embedding_model(
        normalize=db_embedding_model.normalize,
        query_prefix=db_embedding_model.query_prefix,
        passage_prefix=db_embedding_model.passage_prefix,
        provider_type=db_embedding_model.provider_type,
        api_key=db_embedding_model.api_key,
    )
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
logger = setup_logger()


Embedding = list[float]
Embedding = list[float] | None


class ChunkEmbedding(BaseModel):
@@ -34,8 +34,8 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import message_generator_to_string_generator
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
    build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
@@ -99,6 +99,7 @@ class Answer:
        answer_style_config: AnswerStyleConfig,
        llm: LLM,
        prompt_config: PromptConfig,
        force_use_tool: ForceUseTool,
        # must be the same length as `docs`. If None, all docs are considered "relevant"
        message_history: list[PreviousMessage] | None = None,
        single_message_history: str | None = None,
@@ -107,10 +108,8 @@ class Answer:
        latest_query_files: list[InMemoryChatFile] | None = None,
        files: list[InMemoryChatFile] | None = None,
        tools: list[Tool] | None = None,
        # if specified, tells the LLM to always this tool
        # NOTE: for native tool-calling, this is only supported by OpenAI atm,
        # but we only support them anyways
        force_use_tool: ForceUseTool | None = None,
        # if set to True, then never use the LLMs provided tool-calling functonality
        skip_explicit_tool_calling: bool = False,
        # Returns the full document sections text from the search tool
@@ -129,6 +128,7 @@ class Answer:

        self.tools = tools or []
        self.force_use_tool = force_use_tool

        self.skip_explicit_tool_calling = skip_explicit_tool_calling

        self.message_history = message_history or []
@@ -187,7 +187,7 @@ class Answer:
        prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)

        tool_call_chunk: AIMessageChunk | None = None
        if self.force_use_tool and self.force_use_tool.args is not None:
        if self.force_use_tool.force_use and self.force_use_tool.args is not None:
            # if we are forcing a tool WITH args specified, we don't need to check which tools to run
            # / need to generate the args
            tool_call_chunk = AIMessageChunk(
@@ -221,7 +221,7 @@ class Answer:
            for message in self.llm.stream(
                prompt=prompt,
                tools=final_tool_definitions if final_tool_definitions else None,
                tool_choice="required" if self.force_use_tool else None,
                tool_choice="required" if self.force_use_tool.force_use else None,
            ):
                if isinstance(message, AIMessageChunk) and (
                    message.tool_call_chunks or message.tool_calls
@@ -245,7 +245,8 @@ class Answer:
            ][0]
            tool_args = (
                self.force_use_tool.args
                if self.force_use_tool and self.force_use_tool.args
                if self.force_use_tool.tool_name == tool.name
                and self.force_use_tool.args
                else tool_call_request["args"]
            )

@@ -303,7 +304,7 @@ class Answer:

            tool_args = (
                self.force_use_tool.args
                if self.force_use_tool.args
                if self.force_use_tool.args is not None
                else tool.get_args_for_non_tool_calling_llm(
                    query=self.question,
                    history=self.message_history,
@@ -12,8 +12,8 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
@@ -14,8 +14,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
@@ -70,6 +70,7 @@ def load_llm_providers(db_session: Session) -> None:
                FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
            ),
            model_names=model_names,
            is_public=True,
        )
        llm_provider = upsert_llm_provider(db_session, llm_provider_request)
        update_default_provider(db_session, llm_provider.id)
@@ -1,6 +1,5 @@
from collections.abc import Callable
from collections.abc import Iterator
from copy import copy
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -16,10 +15,8 @@ from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from tiktoken.core import Encoding

from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
@@ -28,7 +25,6 @@ from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL
@@ -37,53 +33,6 @@ if TYPE_CHECKING:

logger = setup_logger()

_LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None


def get_default_llm_tokenizer() -> Encoding:
    """Currently only supports the OpenAI default tokenizer: tiktoken"""
    global _LLM_TOKENIZER
    if _LLM_TOKENIZER is None:
        _LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
    return _LLM_TOKENIZER


def get_default_llm_token_encode() -> Callable[[str], Any]:
    global _LLM_TOKENIZER_ENCODE
    if _LLM_TOKENIZER_ENCODE is None:
        tokenizer = get_default_llm_tokenizer()
        if isinstance(tokenizer, Encoding):
            return tokenizer.encode  # type: ignore

        # Currently only supports OpenAI encoder
        raise ValueError("Invalid Encoder selected")

    return _LLM_TOKENIZER_ENCODE


def tokenizer_trim_content(
    content: str, desired_length: int, tokenizer: Encoding
) -> str:
    tokens = tokenizer.encode(content)
    if len(tokens) > desired_length:
        content = tokenizer.decode(tokens[:desired_length])
    return content


def tokenizer_trim_chunks(
    chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
) -> list[InferenceChunk]:
    tokenizer = get_default_llm_tokenizer()
    new_chunks = copy(chunks)
    for ind, chunk in enumerate(new_chunks):
        new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
        if len(new_content) != len(chunk.content):
            new_chunk = copy(chunk)
            new_chunk.content = new_content
            new_chunks[ind] = new_chunk
    return new_chunks


def translate_danswer_msg_to_langchain(
    msg: Union[ChatMessage, "PreviousMessage"],
@@ -50,8 +50,8 @@ from danswer.db.standard_answer import create_initial_default_standard_answer_ca
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.llm.llm_initialization import load_llm_providers
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -1,14 +1,13 @@
import gc
import os
import time
from typing import Optional
from typing import TYPE_CHECKING

import requests
from transformers import logging as transformer_logging  # type:ignore
from httpx import HTTPError

from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.natural_language_processing.utils import get_default_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
@@ -20,50 +19,13 @@ from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse

transformer_logging.set_verbosity_error()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

logger = setup_logger()


if TYPE_CHECKING:
    from transformers import AutoTokenizer  # type: ignore


_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)


def clean_model_name(model_str: str) -> str:
    return model_str.replace("/", "_").replace("-", "_").replace(".", "_")


# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
    # NOTE: doing a local import here to avoid reduce memory usage caused by
    # processes importing this file despite not using any of this
    from transformers import AutoTokenizer  # type: ignore

    global _TOKENIZER
    if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
        if _TOKENIZER[0] is not None:
            del _TOKENIZER
            gc.collect()

        _TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)

        if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

    return _TOKENIZER[0]


def build_model_server_url(
    model_server_host: str,
    model_server_port: int,
@@ -91,6 +53,7 @@ class EmbeddingModel:
        provider_type: str | None,
        # The following are globals are currently not configurable
        max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
        retrim_content: bool = False,
    ) -> None:
        self.api_key = api_key
        self.provider_type = provider_type
@@ -99,32 +62,90 @@ class EmbeddingModel:
        self.passage_prefix = passage_prefix
        self.normalize = normalize
        self.model_name = model_name
        self.retrim_content = retrim_content

        model_server_url = build_model_server_url(server_host, server_port)
        self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

    def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
        if text_type == EmbedTextType.QUERY and self.query_prefix:
            prefixed_texts = [self.query_prefix + text for text in texts]
        elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
            prefixed_texts = [self.passage_prefix + text for text in texts]
        else:
            prefixed_texts = texts
    def encode(
        self,
        texts: list[str],
        text_type: EmbedTextType,
        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
    ) -> list[list[float] | None]:
        if not texts:
            logger.warning("No texts to be embedded")
            return []

        embed_request = EmbedRequest(
            model_name=self.model_name,
            texts=prefixed_texts,
            max_context_length=self.max_seq_length,
            normalize_embeddings=self.normalize,
            api_key=self.api_key,
            provider_type=self.provider_type,
            text_type=text_type,
        )
        if self.retrim_content:
            # This is applied during indexing as a catchall for overly long titles (or other uncapped fields)
            # Note that this uses just the default tokenizer which may also lead to very minor miscountings
            # However this slight miscounting is very unlikely to have any material impact.
            texts = [
                tokenizer_trim_content(
                    content=text,
                    desired_length=self.max_seq_length,
                    tokenizer=get_default_tokenizer(),
                )
                for text in texts
            ]

        response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
        response.raise_for_status()
        if self.provider_type:
            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=texts,
                max_context_length=self.max_seq_length,
                normalize_embeddings=self.normalize,
                api_key=self.api_key,
                provider_type=self.provider_type,
                text_type=text_type,
                manual_query_prefix=self.query_prefix,
                manual_passage_prefix=self.passage_prefix,
            )
            response = requests.post(
                self.embed_server_endpoint, json=embed_request.dict()
            )
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                error_detail = response.json().get("detail", str(e))
                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
            except requests.RequestException as e:
                raise HTTPError(f"Request failed: {str(e)}") from e
            EmbedResponse(**response.json()).embeddings

        return EmbedResponse(**response.json()).embeddings
            return EmbedResponse(**response.json()).embeddings

        # Batching for local embedding
        text_batches = batch_list(texts, batch_size)
        embeddings: list[list[float] | None] = []
        for idx, text_batch in enumerate(text_batches, start=1):
            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=text_batch,
                max_context_length=self.max_seq_length,
                normalize_embeddings=self.normalize,
                api_key=self.api_key,
                provider_type=self.provider_type,
                text_type=text_type,
                manual_query_prefix=self.query_prefix,
                manual_passage_prefix=self.passage_prefix,
            )

            response = requests.post(
                self.embed_server_endpoint, json=embed_request.dict()
            )
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                error_detail = response.json().get("detail", str(e))
                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
            except requests.RequestException as e:
                raise HTTPError(f"Request failed: {str(e)}") from e
            # Normalize embeddings is only configured via model_configs.py, be sure to use right
            # value for the set loss
            embeddings.extend(EmbedResponse(**response.json()).embeddings)
        return embeddings


class CrossEncoderEnsembleModel:
@@ -136,7 +157,7 @@ class CrossEncoderEnsembleModel:
        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"

    def predict(self, query: str, passages: list[str]) -> list[list[float]]:
    def predict(self, query: str, passages: list[str]) -> list[list[float] | None]:
        rerank_request = RerankRequest(query=query, documents=passages)

        response = requests.post(
@@ -199,7 +220,7 @@ def warm_up_encoders(
    # First time downloading the models it may take even longer, but just in case,
    # retry the whole server
    wait_time = 5
    for attempt in range(20):
    for _ in range(20):
        try:
            embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
            return
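A hedged sketch of how a caller might surface the richer error message introduced in the `EmbeddingModel.encode` changes above (the `embedding_model`, `chunk_texts`, `EmbedTextType`, and `logger` names are assumed to already exist in the caller's scope; they are not defined by this snippet):

```python
from httpx import HTTPError  # encode() re-raises as httpx.HTTPError per the diff above

try:
    embeddings = embedding_model.encode(texts=chunk_texts, text_type=EmbedTextType.PASSAGE)
except HTTPError as e:
    # The exception message now carries the model server's "detail" field
    # instead of only a bare HTTP status code.
    logger.error(f"Embedding request failed: {e}")
    raise
```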
backend/danswer/natural_language_processing/utils.py (new file, 100 lines)
@@ -0,0 +1,100 @@
import gc
import os
from collections.abc import Callable
from copy import copy
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING

import tiktoken
from tiktoken.core import Encoding
from transformers import logging as transformer_logging  # type:ignore

from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger


if TYPE_CHECKING:
    from transformers import AutoTokenizer  # type: ignore


logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"


_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
_LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None


# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
    # NOTE: doing a local import here to avoid reduce memory usage caused by
    # processes importing this file despite not using any of this
    from transformers import AutoTokenizer  # type: ignore

    global _TOKENIZER
    if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
        if _TOKENIZER[0] is not None:
            del _TOKENIZER
            gc.collect()

        _TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)

        if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

    return _TOKENIZER[0]


def get_default_llm_tokenizer() -> Encoding:
    """Currently only supports the OpenAI default tokenizer: tiktoken"""
    global _LLM_TOKENIZER
    if _LLM_TOKENIZER is None:
        _LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
    return _LLM_TOKENIZER


def get_default_llm_token_encode() -> Callable[[str], Any]:
    global _LLM_TOKENIZER_ENCODE
    if _LLM_TOKENIZER_ENCODE is None:
        tokenizer = get_default_llm_tokenizer()
        if isinstance(tokenizer, Encoding):
            return tokenizer.encode  # type: ignore

        # Currently only supports OpenAI encoder
        raise ValueError("Invalid Encoder selected")

    return _LLM_TOKENIZER_ENCODE


def tokenizer_trim_content(
    content: str, desired_length: int, tokenizer: Encoding
) -> str:
    tokens = tokenizer.encode(content)
    if len(tokens) > desired_length:
        content = tokenizer.decode(tokens[:desired_length])
    return content


def tokenizer_trim_chunks(
    chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
) -> list[InferenceChunk]:
    tokenizer = get_default_llm_tokenizer()
    new_chunks = copy(chunks)
    for ind, chunk in enumerate(new_chunks):
        new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
        if len(new_content) != len(chunk.content):
            new_chunk = copy(chunk)
            new_chunk.content = new_content
            new_chunks[ind] = new_chunk
    return new_chunks
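A minimal usage sketch of the tokenizer helpers in the new module above (the sample string and token budget are illustrative):

```python
from danswer.natural_language_processing.utils import (
    get_default_llm_tokenizer,
    tokenizer_trim_content,
)

tokenizer = get_default_llm_tokenizer()   # tiktoken "cl100k_base" encoding
trimmed = tokenizer_trim_content(
    content="some very long passage " * 200,
    desired_length=512,                   # keep at most 512 tokens
    tokenizer=tokenizer,
)
assert len(tokenizer.encode(trimmed)) <= 512
```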
@@ -34,7 +34,7 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
@@ -206,6 +206,7 @@ def stream_answer_objects(
        single_message_history=history_str,
        tools=[search_tool],
        force_use_tool=ForceUseTool(
            force_use=True,
            tool_name=search_tool.name,
            args={"query": rephrased_query},
        ),
@@ -256,6 +257,9 @@ def stream_answer_objects(
            )
            yield initial_response

        elif packet.id == SEARCH_DOC_CONTENT_ID:
            yield packet.response

        elif packet.id == SECTION_RELEVANCE_LIST_ID:
            chunk_indices = packet.response

@@ -267,9 +271,12 @@ def stream_answer_objects(
            )

            yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)

        elif packet.id == SEARCH_DOC_CONTENT_ID:
            yield packet.response
            if query_req.skip_gen_ai_answer_generation:
                # Exit early if only source docs + contexts are requested
                # Putting exit here assumes that a packet with the ID
                # SECTION_RELEVANCE_LIST_ID is the last one yielded before
                # calling the LLM
                return

        elif packet.id == SEARCH_EVALUATION_ID:
            evaluation_response = LLMRelevanceSummaryResponse(
@@ -2,7 +2,7 @@ from collections.abc import Callable
from collections.abc import Generator

from danswer.configs.constants import MessageType
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import ThreadMessage
from danswer.utils.logger import setup_logger


@@ -34,6 +34,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time

logger = setup_logger()
@@ -154,6 +155,7 @@ class SearchPipeline:

        return cast(list[InferenceChunk], self._retrieved_chunks)

    @log_function_time(print_only=True)
    def _get_sections(self) -> list[InferenceSection]:
        """Returns an expanded section from each of the chunks.
        If whole docs (instead of above/below context) is specified then it will give back all of the whole docs
@@ -173,9 +175,11 @@ class SearchPipeline:
        expanded_inference_sections = []

        # Full doc setting takes priority

        if self.search_query.full_doc:
            seen_document_ids = set()
            unique_chunks = []

            # This preserves the ordering since the chunks are retrieved in score order
            for chunk in retrieved_chunks:
                if chunk.document_id not in seen_document_ids:
@@ -195,7 +199,6 @@ class SearchPipeline:
                        ),
                    )
                )

            list_inference_chunks = run_functions_tuples_in_parallel(
                functions_with_args, allow_failures=False
            )
@@ -240,32 +243,35 @@ class SearchPipeline:
        merged_ranges = [
            merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
        ]
        flat_ranges = [r for ranges in merged_ranges for r in ranges]

        flat_ranges: list[ChunkRange] = [r for ranges in merged_ranges for r in ranges]
        flattened_inference_chunks: list[InferenceChunk] = []
        parallel_functions_with_args = []

        for chunk_range in flat_ranges:
            functions_with_args.append(
                (
                    # If Large Chunks are introduced, additional filters need to be added here
                    self.document_index.id_based_retrieval,
                    (
                        # Only need the document_id here, just use any chunk in the range is fine
                        chunk_range.chunks[0].document_id,
                        chunk_range.start,
                        chunk_range.end,
                        # There is no chunk level permissioning, this expansion around chunks
                        # can be assumed to be safe
                        IndexFilters(access_control_list=None),
                    ),
                )
            )
            # Don't need to fetch chunks within range for merging if chunk_above / below are 0.
            if above == below == 0:
                flattened_inference_chunks.extend(chunk_range.chunks)

        # list of list of inference chunks where the inner list needs to be combined for content
        list_inference_chunks = run_functions_tuples_in_parallel(
            functions_with_args, allow_failures=False
        )
        flattened_inference_chunks = [
            chunk for sublist in list_inference_chunks for chunk in sublist
        ]
            else:
                parallel_functions_with_args.append(
                    (
                        self.document_index.id_based_retrieval,
                        (
                            chunk_range.chunks[0].document_id,
                            chunk_range.start,
                            chunk_range.end,
                            IndexFilters(access_control_list=None),
                        ),
                    )
                )

        if parallel_functions_with_args:
            list_inference_chunks = run_functions_tuples_in_parallel(
                parallel_functions_with_args, allow_failures=False
            )
            for inference_chunks in list_inference_chunks:
                flattened_inference_chunks.extend(inference_chunks)

        doc_chunk_ind_to_chunk = {
            (chunk.document_id, chunk.chunk_id): chunk
||||
@@ -12,6 +12,9 @@ from danswer.document_index.document_index_utils import (
|
||||
translate_boost_count_to_multiplier,
|
||||
)
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.natural_language_processing.search_nlp_models import (
|
||||
CrossEncoderEnsembleModel,
|
||||
)
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
@@ -20,7 +23,6 @@ from danswer.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
|
||||
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
|
||||
from danswer.natural_language_processing.search_nlp_models import IntentModel
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
|
||||
from danswer.search.search_nlp_models import get_default_tokenizer
|
||||
from danswer.search.search_nlp_models import IntentModel
|
||||
from danswer.server.query_and_chat.models import HelperResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import string
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
import nltk # type:ignore
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
@@ -11,6 +12,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunk
|
||||
@@ -20,7 +22,6 @@ from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.postprocessing.postprocessing import cleanup_chunks
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.search.utils import inference_section_from_chunks
|
||||
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -143,7 +144,9 @@ def doc_index_retrieval(
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
top_chunks = document_index.semantic_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
query_embedding=cast(
|
||||
list[float], query_embedding
|
||||
), # query embeddings should always have vector representations
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
@@ -152,7 +155,9 @@ def doc_index_retrieval(
|
||||
elif query.search_type == SearchType.HYBRID:
|
||||
top_chunks = document_index.hybrid_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
query_embedding=cast(
|
||||
list[float], query_embedding
|
||||
), # query embeddings should always have vector representations
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
|
||||
@@ -94,7 +94,7 @@ def history_based_query_rephrase(
|
||||
llm: LLM,
|
||||
size_heuristic: int = 200,
|
||||
punctuation_heuristic: int = 10,
|
||||
skip_first_rephrase: bool = False,
|
||||
skip_first_rephrase: bool = True,
|
||||
prompt_template: str = HISTORY_QUERY_REPHRASE,
|
||||
) -> str:
|
||||
# Globally disabled, just use the exact user query
|
||||
|
||||
@@ -96,6 +96,8 @@ def upsert_ingestion_doc(
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
api_key=db_embedding_model.api_key,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
@@ -132,6 +134,8 @@ def upsert_ingestion_doc(
|
||||
normalize=sec_db_embedding_model.normalize,
|
||||
query_prefix=sec_db_embedding_model.query_prefix,
|
||||
passage_prefix=sec_db_embedding_model.passage_prefix,
|
||||
api_key=sec_db_embedding_model.api_key,
|
||||
provider_type=sec_db_embedding_model.provider_type,
|
||||
)
|
||||
|
||||
sec_ind_pipeline = build_indexing_pipeline(
|
||||
|
||||
@@ -9,7 +9,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
from danswer.natural_language_processing.utils import get_default_llm_token_encode
|
||||
from danswer.prompts.prompt_utils import build_doc_context_str
|
||||
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
|
||||
from danswer.server.documents.models import ChunkInfo
|
||||
|
||||
@@ -10,7 +10,7 @@ from danswer.db.llm import fetch_existing_embedding_providers
|
||||
from danswer.db.llm import remove_embedding_provider
|
||||
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||
from danswer.db.models import User
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.embedding.models import TestEmbeddingRequest
|
||||
@@ -42,7 +42,7 @@ def test_embedding_configuration(
|
||||
passage_prefix=None,
|
||||
model_name=None,
|
||||
)
|
||||
test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
|
||||
test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = f"Not a valid embedding model. Exception thrown: {e}"
|
||||
|
||||
@@ -147,10 +147,10 @@ def set_provider_as_default(
|
||||
|
||||
@basic_router.get("/provider")
|
||||
def list_llm_provider_basics(
|
||||
_: User | None = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
return [
|
||||
LLMProviderDescriptor.from_model(llm_provider_model)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
|
||||
]
|
||||
|
||||
@@ -60,6 +60,8 @@ class LLMProvider(BaseModel):
|
||||
custom_config: dict[str, str] | None
|
||||
default_model_name: str
|
||||
fast_default_model_name: str | None
|
||||
is_public: bool = True
|
||||
groups: list[int] | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
@@ -91,4 +93,6 @@ class FullLLMProvider(LLMProvider):
|
||||
or fetch_models_for_provider(llm_provider_model.provider)
|
||||
or [llm_provider_model.default_model_name]
|
||||
),
|
||||
is_public=llm_provider_model.is_public,
|
||||
groups=[group.id for group in llm_provider_model.groups],
|
||||
)
|
||||
|
||||
@@ -45,7 +45,7 @@ from danswer.llm.answering.prompts.citations_prompt import (
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.headers import get_litellm_additional_request_headers
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.secondary_llm_flows.chat_session_naming import (
|
||||
get_renamed_conversation_name,
|
||||
)
|
||||
|
||||
@@ -90,7 +90,7 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
parent_message_id: int | None
|
||||
# New message contents
|
||||
message: str
|
||||
# file's that we should attach to this message
|
||||
# Files that we should attach to this message
|
||||
file_descriptors: list[FileDescriptor]
|
||||
# If no prompt provided, uses the largest prompt of the chat session
|
||||
# but really this should be explicitly specified, only in the simplified APIs is this inferred
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
class ForceUseTool(BaseModel):
|
||||
# May not be a forced usage of the tool but can still have args, in which case
|
||||
# if the tool is called, then those args are applied instead of what the LLM
|
||||
# wanted to call it with
|
||||
force_use: bool
|
||||
tool_name: str
|
||||
args: dict[str, Any] | None = None
|
||||
|
||||
@@ -16,25 +18,10 @@ class ForceUseTool(BaseModel):
|
||||
return {"type": "function", "function": {"name": self.tool_name}}
|
||||
|
||||
|
||||
def modify_message_chain_for_force_use_tool(
|
||||
messages: list[BaseMessage], force_use_tool: ForceUseTool | None = None
|
||||
) -> list[BaseMessage]:
|
||||
"""NOTE: modifies `messages` in place."""
|
||||
if not force_use_tool:
|
||||
return messages
|
||||
|
||||
for message in messages:
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_call["args"] = force_use_tool.args or {}
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def filter_tools_for_force_tool_use(
|
||||
tools: list[Tool], force_use_tool: ForceUseTool | None = None
|
||||
tools: list[Tool], force_use_tool: ForceUseTool
|
||||
) -> list[Tool]:
|
||||
if not force_use_tool:
|
||||
if not force_use_tool.force_use:
|
||||
return tools
|
||||
|
||||
return [tool for tool in tools if tool.name == force_use_tool.tool_name]
|
||||
|
||||
@@ -6,7 +6,7 @@ from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.messages.tool import ToolMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
|
||||
|
||||
def build_tool_message(
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
|
||||
from tiktoken import Encoding
|
||||
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import TokenRateLimit__UserGroup
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
@@ -194,6 +195,15 @@ def _cleanup_user__user_group_relationships__no_commit(
|
||||
db_session.delete(user__user_group_relationship)
|
||||
|
||||
|
||||
def _cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
db_session.query(LLMProvider__UserGroup).filter(
|
||||
LLMProvider__UserGroup.user_group_id == user_group_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
@@ -316,6 +326,9 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
|
||||
|
||||
|
||||
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
|
||||
_cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id
|
||||
)
|
||||
_cleanup_user__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from cohere import Client as CohereClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from retry import retry
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||
@@ -40,110 +41,133 @@ router = APIRouter(prefix="/encoder")
|
||||
|
||||
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
||||
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
|
||||
# If we are not only indexing, we don't want to retry for very long
|
||||
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
|
||||
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
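# In practice this means the indexing-only model server retries for a while (10 tries, 10s apart),
# while the realtime server gives up after 2 quick attempts.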
|
||||
|
||||
|
||||
def _initialize_client(
|
||||
api_key: str, provider: EmbeddingProvider, model: str | None = None
|
||||
) -> Any:
|
||||
if provider == EmbeddingProvider.OPENAI:
|
||||
return openai.OpenAI(api_key=api_key)
|
||||
elif provider == EmbeddingProvider.COHERE:
|
||||
return CohereClient(api_key=api_key)
|
||||
elif provider == EmbeddingProvider.VOYAGE:
|
||||
return voyageai.Client(api_key=api_key)
|
||||
elif provider == EmbeddingProvider.GOOGLE:
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(api_key)
|
||||
)
|
||||
project_id = json.loads(api_key)["project_id"]
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
|
||||
|
||||
class CloudEmbedding:
|
||||
def __init__(self, api_key: str, provider: str, model: str | None = None):
|
||||
self.api_key = api_key
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
provider: str,
|
||||
# Only for Google, as it is needed during client setup
|
||||
self.model = model
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
self.provider = EmbeddingProvider(provider.lower())
|
||||
except ValueError:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
self.client = self._initialize_client()
|
||||
self.client = _initialize_client(api_key, self.provider, model)
|
||||
|
||||
def _initialize_client(self) -> Any:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return openai.OpenAI(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.COHERE:
|
||||
return CohereClient(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return voyageai.Client(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(self.api_key)
|
||||
)
|
||||
project_id = json.loads(self.api_key)["project_id"]
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
return TextEmbeddingModel.from_pretrained(
|
||||
self.model or DEFAULT_VERTEX_MODEL
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
def encode(
|
||||
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
|
||||
) -> list[list[float]]:
|
||||
return [
|
||||
self.embed(text=text, text_type=text_type, model=model_name)
|
||||
for text in texts
|
||||
]
|
||||
|
||||
def embed(
|
||||
self, *, text: str, text_type: EmbedTextType, model: str | None = None
|
||||
) -> list[float]:
|
||||
logger.debug(f"Embedding text with provider: {self.provider}")
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(text, model)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(text, model, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return self._embed_voyage(text, model, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return self._embed_vertex(text, model, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
def _embed_openai(self, text: str, model: str | None) -> list[float]:
|
||||
def _embed_openai(
|
||||
self, texts: list[str], model: str | None
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_OPENAI_MODEL
|
||||
|
||||
response = self.client.embeddings.create(input=text, model=model)
|
||||
return response.data[0].embedding
|
||||
# OpenAI does not seem to provide a truncation option, however
|
||||
# the context lengths used by Danswer currently are smaller than the max token length
|
||||
# for OpenAI embeddings so it's not a big deal
|
||||
response = self.client.embeddings.create(input=texts, model=model)
|
||||
return [embedding.embedding for embedding in response.data]
|
||||
|
||||
def _embed_cohere(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_COHERE_MODEL
|
||||
|
||||
# Does not use the same tokenizer as the Danswer API server but it's approximately the same
|
||||
# empirically it's only off by a few tokens so it's not a big deal
|
||||
response = self.client.embed(
|
||||
texts=[text],
|
||||
texts=texts,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
truncate="END",
|
||||
)
|
||||
return response.embeddings[0]
|
||||
return response.embeddings
|
||||
|
||||
def _embed_voyage(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_VOYAGE_MODEL
|
||||
|
||||
response = self.client.embed(text, model=model, input_type=embedding_type)
|
||||
return response.embeddings[0]
|
||||
# Similar to Cohere, the API server will do approximate size chunking
|
||||
# it's acceptable to miss by a few tokens
|
||||
response = self.client.embed(
|
||||
texts,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
truncation=True, # Also this is default
|
||||
)
|
||||
return response.embeddings
|
||||
|
||||
def _embed_vertex(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_VERTEX_MODEL
|
||||
|
||||
embedding = self.client.get_embeddings(
|
||||
embeddings = self.client.get_embeddings(
|
||||
[
|
||||
TextEmbeddingInput(
|
||||
text,
|
||||
embedding_type,
|
||||
)
|
||||
]
|
||||
for text in texts
|
||||
],
|
||||
auto_truncate=True, # Also this is default
|
||||
)
|
||||
return embedding[0].values
|
||||
return [embedding.values for embedding in embeddings]
|
||||
|
||||
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
|
||||
def embed(
|
||||
self,
|
||||
*,
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None = None,
|
||||
) -> list[list[float] | None]:
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error embedding text with {self.provider}: {str(e)}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
@@ -212,34 +236,83 @@ def embed_text(
|
||||
normalize_embeddings: bool,
|
||||
api_key: str | None,
|
||||
provider_type: str | None,
|
||||
) -> list[list[float]]:
|
||||
if provider_type is not None:
|
||||
prefix: str | None,
|
||||
) -> list[list[float] | None]:
|
||||
non_empty_texts = []
|
||||
empty_indices = []
|
||||
|
||||
for idx, text in enumerate(texts):
|
||||
if text.strip():
|
||||
non_empty_texts.append(text)
|
||||
else:
|
||||
empty_indices.append(idx)
|
||||
|
||||
# Third party API based embedding model
|
||||
if not non_empty_texts:
|
||||
embeddings = []
|
||||
elif provider_type is not None:
|
||||
logger.debug(f"Embedding text with provider: {provider_type}")
|
||||
if api_key is None:
|
||||
raise RuntimeError("API key not provided for cloud model")
|
||||
|
||||
if prefix:
|
||||
# This may change in the future if some providers require the user
|
||||
# to manually append a prefix but this is not the case currently
|
||||
raise ValueError(
|
||||
"Prefix string is not valid for cloud models. "
|
||||
"Cloud models take an explicit text type instead."
|
||||
)
|
||||
|
||||
cloud_model = CloudEmbedding(
|
||||
api_key=api_key, provider=provider_type, model=model_name
|
||||
)
|
||||
embeddings = cloud_model.encode(texts, model_name, text_type)
|
||||
embeddings = cloud_model.embed(
|
||||
texts=non_empty_texts,
|
||||
model_name=model_name,
|
||||
text_type=text_type,
|
||||
)
|
||||
|
||||
elif model_name is not None:
|
||||
hosted_model = get_embedding_model(
|
||||
prefixed_texts = (
|
||||
[f"{prefix}{text}" for text in non_empty_texts]
|
||||
if prefix
|
||||
else non_empty_texts
|
||||
)
|
||||
local_model = get_embedding_model(
|
||||
model_name=model_name, max_context_length=max_context_length
|
||||
)
|
||||
embeddings = hosted_model.encode(
|
||||
texts, normalize_embeddings=normalize_embeddings
|
||||
embeddings = local_model.encode(
|
||||
prefixed_texts, normalize_embeddings=normalize_embeddings
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either model name or provider must be provided to run embeddings."
|
||||
)
|
||||
|
||||
if embeddings is None:
|
||||
raise RuntimeError("Embeddings were not created")
|
||||
raise RuntimeError("Failed to create Embeddings")
|
||||
|
||||
if not isinstance(embeddings, list):
|
||||
embeddings = embeddings.tolist()
|
||||
embeddings_with_nulls: list[list[float] | None] = []
|
||||
current_embedding_index = 0
|
||||
|
||||
for idx in range(len(texts)):
|
||||
if idx in empty_indices:
|
||||
embeddings_with_nulls.append(None)
|
||||
else:
|
||||
embedding = embeddings[current_embedding_index]
|
||||
if isinstance(embedding, list) or embedding is None:
|
||||
embeddings_with_nulls.append(embedding)
|
||||
else:
|
||||
embeddings_with_nulls.append(embedding.tolist())
|
||||
current_embedding_index += 1
|
||||
|
||||
embeddings = embeddings_with_nulls
|
||||
return embeddings
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
|
||||
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float] | None]:
|
||||
cross_encoders = get_local_reranking_model_ensemble()
|
||||
sim_scores = [
|
||||
encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
|
||||
@@ -252,7 +325,17 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest,
|
||||
) -> EmbedResponse:
|
||||
if not embed_request.texts:
|
||||
raise HTTPException(status_code=400, detail="No texts to be embedded")
|
||||
|
||||
try:
|
||||
if embed_request.text_type == EmbedTextType.QUERY:
|
||||
prefix = embed_request.manual_query_prefix
|
||||
elif embed_request.text_type == EmbedTextType.PASSAGE:
|
||||
prefix = embed_request.manual_passage_prefix
|
||||
else:
|
||||
prefix = None
|
||||
|
||||
embeddings = embed_text(
|
||||
texts=embed_request.texts,
|
||||
model_name=embed_request.model_name,
|
||||
@@ -261,13 +344,13 @@ async def process_embed_request(
|
||||
api_key=embed_request.api_key,
|
||||
provider_type=embed_request.provider_type,
|
||||
text_type=embed_request.text_type,
|
||||
prefix=prefix,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during embedding process:\n{str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to run Bi-Encoder embedding"
|
||||
)
|
||||
exception_detail = f"Error during embedding process:\n{str(e)}"
|
||||
logger.exception(exception_detail)
|
||||
raise HTTPException(status_code=500, detail=exception_detail)
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
@@ -276,6 +359,11 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
if not embed_request.documents or not embed_request.query:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No documents or query to be reranked"
|
||||
)
|
||||
|
||||
try:
|
||||
sim_scores = calc_sim_scores(
|
||||
query=embed_request.query, docs=embed_request.documents
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
fastapi==0.109.2
|
||||
h5py==3.9.0
|
||||
pydantic==1.10.13
|
||||
retry==0.9.2
|
||||
safetensors==0.4.2
|
||||
sentence-transformers==2.6.1
|
||||
tensorflow==2.15.0
|
||||
@@ -9,5 +10,5 @@ transformers==4.39.2
|
||||
uvicorn==0.21.1
|
||||
voyageai==0.2.3
|
||||
openai==1.14.3
|
||||
cohere==5.5.8
|
||||
google-cloud-aiplatform==1.58.0
|
||||
cohere==5.6.1
|
||||
google-cloud-aiplatform==1.58.0
|
||||
|
||||
@@ -14,7 +14,10 @@ sys.path.append(parent_dir)
|
||||
# flake8: noqa: E402
|
||||
|
||||
# Now import Danswer modules
|
||||
from danswer.db.models import DocumentSet__ConnectorCredentialPair
|
||||
from danswer.db.models import (
|
||||
DocumentSet__ConnectorCredentialPair,
|
||||
UserGroup__ConnectorCredentialPair,
|
||||
)
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.index_attempt import (
|
||||
@@ -44,7 +47,7 @@ logger = setup_logger()
|
||||
_DELETION_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def unsafe_deletion(
|
||||
def _unsafe_deletion(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
@@ -82,11 +85,22 @@ def unsafe_deletion(
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
# Delete document sets + connector / credential Pairs
|
||||
# Delete document sets
|
||||
stmt = delete(DocumentSet__ConnectorCredentialPair).where(
|
||||
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id == pair_id
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
# delete user group associations
|
||||
stmt = delete(UserGroup__ConnectorCredentialPair).where(
|
||||
UserGroup__ConnectorCredentialPair.cc_pair_id == pair_id
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
# need to flush to avoid foreign key violations
|
||||
db_session.flush()
|
||||
|
||||
# delete the actual connector credential pair
|
||||
stmt = delete(ConnectorCredentialPair).where(
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
@@ -168,7 +182,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
files_deleted_count = unsafe_deletion(
|
||||
files_deleted_count = _unsafe_deletion(
|
||||
db_session=db_session,
|
||||
document_index=document_index,
|
||||
cc_pair=cc_pair,
|
||||
|
||||
@@ -4,9 +4,7 @@ from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
class EmbedRequest(BaseModel):
|
||||
# This already includes any prefixes, the text is just passed directly to the model
|
||||
texts: list[str]
|
||||
|
||||
# Can be none for cloud embedding model requests, error handling logic exists for other cases
|
||||
model_name: str | None
|
||||
max_context_length: int
|
||||
@@ -14,10 +12,12 @@ class EmbedRequest(BaseModel):
|
||||
api_key: str | None
|
||||
provider_type: str | None
|
||||
text_type: EmbedTextType
|
||||
manual_query_prefix: str | None
|
||||
manual_passage_prefix: str | None
|
||||
|
||||
|
||||
class EmbedResponse(BaseModel):
|
||||
embeddings: list[list[float]]
|
||||
embeddings: list[list[float] | None]
|
||||
|
||||
|
||||
class RerankRequest(BaseModel):
|
||||
@@ -26,7 +26,7 @@ class RerankRequest(BaseModel):
|
||||
|
||||
|
||||
class RerankResponse(BaseModel):
|
||||
scores: list[list[float]]
|
||||
scores: list[list[float] | None]
|
||||
|
||||
|
||||
class IntentRequest(BaseModel):
|
||||
|
||||
@@ -9,7 +9,7 @@ This Python script automates the process of running search quality tests for a b
|
||||
- Manages environment variables
|
||||
- Switches to specified Git branch
|
||||
- Uploads test documents
|
||||
- Runs search quality tests using Relari
|
||||
- Runs search quality tests
|
||||
- Cleans up Docker containers (optional)
|
||||
|
||||
## Usage
|
||||
@@ -29,9 +29,17 @@ export PYTHONPATH=$PYTHONPATH:$PWD/backend
|
||||
```
|
||||
cd backend/tests/regression/answer_quality
|
||||
```
|
||||
7. Run the script:
|
||||
7. To launch the evaluation environment, run the launch_eval_env.py script (this step can be skipped if you are running the environment outside of Docker; just leave "environment_name" blank):
|
||||
```
|
||||
python run_eval_pipeline.py
|
||||
python launch_eval_env.py
|
||||
```
|
||||
8. Run the file_uploader.py script to upload the zip files located at the path "zipped_documents_file"
|
||||
```
|
||||
python file_uploader.py
|
||||
```
|
||||
9. Run the run_qa.py script to ask questions from the jsonl located at the path "questions_file". This will hit the "query/answer-with-quote" API endpoint.
|
||||
```
|
||||
python run_qa.py
|
||||
```
|
||||
|
||||
Note: All data will be saved even after the containers are shut down. There are instructions below for re-launching the Docker containers using this data.
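For orientation, here is a minimal sketch of what run_qa.py does for each question: it posts the query to the "query/answer-with-quote" endpoint of the running API server and reads back the answer plus the retrieved contexts. The field names mirror the request models used by the test scripts, but the default port, the persona/prompt ids, and the enum string values ("user", "always") are assumptions, not a verified client.

```
import requests

API_URL = "http://localhost:8080/query/answer-with-quote/"  # assumed default port


def ask_question(query: str, only_retrieve_docs: bool = False) -> tuple[list[dict], str]:
    # Post one question and return (retrieved context docs, answer text).
    body = {
        "messages": [{"message": query, "sender": None, "role": "user"}],
        "prompt_id": 0,   # assumed default prompt
        "persona_id": 0,  # assumed default persona
        "retrieval_options": {
            "run_search": "always",
            "real_time": True,
            "enable_auto_detect_filters": False,
        },
        "chain_of_thought": False,
        "return_contexts": True,
        "skip_gen_ai_answer_generation": only_retrieve_docs,
    }
    response = requests.post(
        API_URL, json=body, headers={"Content-Type": "application/json"}
    )
    response.raise_for_status()
    data = response.json()
    context_docs = data.get("contexts", {}).get("contexts", [])
    answer = data.get("answer", "") or ""
    return context_docs, answer


if __name__ == "__main__":
    docs, answer = ask_question("What connectors does Danswer support?")
    print(f"Retrieved {len(docs)} context docs")
    print(answer[:300])
```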
|
||||
@@ -61,6 +69,11 @@ Edit `search_test_config.yaml` to set:
|
||||
- Set this to true to automatically delete all docker containers, networks and volumes after the test
|
||||
- launch_web_ui
|
||||
- Set this to true if you want to use the UI during/after the testing process
|
||||
- only_state
|
||||
- Whether to only run Vespa and Postgres
|
||||
- only_retrieve_docs
|
||||
- Set to true to only retrieve documents, not the LLM response
|
||||
- This is to save on API costs
|
||||
- use_cloud_gpu
|
||||
- Set to true or false depending on whether you want to use the remote GPU
|
||||
- Only need to set this if use_cloud_gpu is true
|
||||
@@ -70,12 +83,10 @@ Edit `search_test_config.yaml` to set:
|
||||
- model_server_port
|
||||
- This is the port of the remote model server
|
||||
- Only need to set this if use_cloud_gpu is true
|
||||
- existing_test_suffix (THIS IS NOT A SUFFIX ANYMORE, TODO UPDATE THE DOCS HERE)
|
||||
- environment_name
|
||||
- Use this if you would like to relaunch a previous test instance
|
||||
- Input the suffix of the test you'd like to re-launch
|
||||
- (E.g. to use the data from folder "test-1234-5678" put "-1234-5678")
|
||||
- No new files will automatically be uploaded
|
||||
- Leave empty to run a new test
|
||||
- Input the env_name of the test you'd like to re-launch
|
||||
- Leave empty to launch using the default local network locations
|
||||
- limit
|
||||
- Max number of questions you'd like to ask against the dataset
|
||||
- Set to null for no limit
|
||||
@@ -85,7 +96,7 @@ Edit `search_test_config.yaml` to set:
|
||||
|
||||
## Relaunching From Existing Data
|
||||
|
||||
To launch an existing set of containers that has already completed indexing, set the existing_test_suffix variable. This will launch the docker containers mounted on the volumes of the indicated suffix and will not automatically index any documents or run any QA.
|
||||
To launch an existing set of containers that has already completed indexing, set the environment_name variable. This will launch the docker containers mounted on the volumes of the indicated env_name and will not automatically index any documents or run any QA.
|
||||
|
||||
Once these containers are launched you can run file_uploader.py or run_qa.py (assuming you have run the steps in the Usage section above).
|
||||
- file_uploader.py will upload and index additional zipped files located at the zipped_documents_file path.
|
||||
|
||||
@@ -2,45 +2,31 @@ import requests
|
||||
from retry import retry
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import OptionalSearchSetting
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.documents.models import ConnectorBase
|
||||
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||
from tests.regression.answer_quality.cli_utils import get_api_server_host_port
|
||||
|
||||
GENERAL_HEADERS = {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
def _api_url_builder(run_suffix: str, api_path: str) -> str:
|
||||
return f"http://localhost:{get_api_server_host_port(run_suffix)}" + api_path
|
||||
|
||||
|
||||
def _create_new_chat_session(run_suffix: str) -> int:
|
||||
create_chat_request = ChatSessionCreationRequest(
|
||||
persona_id=0,
|
||||
description=None,
|
||||
)
|
||||
body = create_chat_request.dict()
|
||||
|
||||
create_chat_url = _api_url_builder(run_suffix, "/chat/create-chat-session/")
|
||||
|
||||
response_json = requests.post(
|
||||
create_chat_url, headers=GENERAL_HEADERS, json=body
|
||||
).json()
|
||||
chat_session_id = response_json.get("chat_session_id")
|
||||
|
||||
if isinstance(chat_session_id, int):
|
||||
return chat_session_id
|
||||
def _api_url_builder(env_name: str, api_path: str) -> str:
|
||||
if env_name:
|
||||
return f"http://localhost:{get_api_server_host_port(env_name)}" + api_path
|
||||
else:
|
||||
raise RuntimeError(response_json)
|
||||
return "http://localhost:8080" + api_path
|
||||
|
||||
|
||||
@retry(tries=10, delay=10)
|
||||
def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
|
||||
@retry(tries=5, delay=5)
|
||||
def get_answer_from_query(
|
||||
query: str, only_retrieve_docs: bool, env_name: str
|
||||
) -> tuple[list[str], str]:
|
||||
filters = IndexFilters(
|
||||
source_type=None,
|
||||
document_set=None,
|
||||
@@ -48,42 +34,47 @@ def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
|
||||
tags=None,
|
||||
access_control_list=None,
|
||||
)
|
||||
retrieval_options = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=True,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=False,
|
||||
|
||||
messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]
|
||||
|
||||
new_message_request = DirectQARequest(
|
||||
messages=messages,
|
||||
prompt_id=0,
|
||||
persona_id=0,
|
||||
retrieval_options=RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=True,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=False,
|
||||
),
|
||||
chain_of_thought=False,
|
||||
return_contexts=True,
|
||||
skip_gen_ai_answer_generation=only_retrieve_docs,
|
||||
)
|
||||
|
||||
chat_session_id = _create_new_chat_session(run_suffix)
|
||||
|
||||
url = _api_url_builder(run_suffix, "/chat/send-message-simple-api/")
|
||||
|
||||
new_message_request = BasicCreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
message=query,
|
||||
retrieval_options=retrieval_options,
|
||||
query_override=query,
|
||||
)
|
||||
url = _api_url_builder(env_name, "/query/answer-with-quote/")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
body = new_message_request.dict()
|
||||
body["user"] = None
|
||||
try:
|
||||
response_json = requests.post(url, headers=GENERAL_HEADERS, json=body).json()
|
||||
simple_search_docs = response_json.get("simple_search_docs", [])
|
||||
answer = response_json.get("answer", "")
|
||||
response_json = requests.post(url, headers=headers, json=body).json()
|
||||
context_data_list = response_json.get("contexts", {}).get("contexts", [])
|
||||
answer = response_json.get("answer", "") or ""
|
||||
except Exception as e:
|
||||
print("Failed to answer the questions:")
|
||||
print(f"\t {str(e)}")
|
||||
print("trying again")
|
||||
print("Try restarting vespa container and trying agian")
|
||||
raise e
|
||||
|
||||
return simple_search_docs, answer
|
||||
return context_data_list, answer
|
||||
|
||||
|
||||
@retry(tries=10, delay=10)
|
||||
def check_if_query_ready(run_suffix: str) -> bool:
|
||||
url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/")
|
||||
def check_indexing_status(env_name: str) -> tuple[int, bool]:
|
||||
url = _api_url_builder(env_name, "/manage/admin/connector/indexing-status/")
|
||||
try:
|
||||
indexing_status_dict = requests.get(url, headers=GENERAL_HEADERS).json()
|
||||
except Exception as e:
|
||||
@@ -98,20 +89,21 @@ def check_if_query_ready(run_suffix: str) -> bool:
|
||||
status = index_attempt["last_status"]
|
||||
if status == IndexingStatus.IN_PROGRESS or status == IndexingStatus.NOT_STARTED:
|
||||
ongoing_index_attempts = True
|
||||
elif status == IndexingStatus.SUCCESS:
|
||||
doc_count += 16
|
||||
doc_count += index_attempt["docs_indexed"]
|
||||
doc_count -= 16
|
||||
|
||||
if not doc_count:
|
||||
print("No docs indexed, waiting for indexing to start")
|
||||
elif ongoing_index_attempts:
|
||||
print(
|
||||
f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
|
||||
)
|
||||
|
||||
return doc_count > 0 and not ongoing_index_attempts
|
||||
# all the +16 and -16 are to account for the fact that the indexing status
|
||||
# is only updated every 16 documents and will tell us how many are
|
||||
# chunked, not indexed. This probably needs to be fixed in the future!
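# Illustrative trace of the adjustment (made-up numbers): two successful attempts
# reporting docs_indexed of 20 and 35 give doc_count = (20 - 16) + (35 - 16) = 23;
# since doc_count is non-zero, 16 is added back below, so 39 is returned alongside
# the ongoing-attempts flag.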
|
||||
if doc_count:
|
||||
doc_count += 16
|
||||
return doc_count, ongoing_index_attempts
|
||||
|
||||
|
||||
def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
|
||||
url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/")
|
||||
def run_cc_once(env_name: str, connector_id: int, credential_id: int) -> None:
|
||||
url = _api_url_builder(env_name, "/manage/admin/connector/run-once/")
|
||||
body = {
|
||||
"connector_id": connector_id,
|
||||
"credential_ids": [credential_id],
|
||||
@@ -126,9 +118,9 @@ def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
|
||||
print("Failed text:", response.text)
|
||||
|
||||
|
||||
def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> None:
|
||||
def create_cc_pair(env_name: str, connector_id: int, credential_id: int) -> None:
|
||||
url = _api_url_builder(
|
||||
run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}"
|
||||
env_name, f"/manage/connector/{connector_id}/credential/{credential_id}"
|
||||
)
|
||||
|
||||
body = {"name": "zip_folder_contents", "is_public": True}
|
||||
@@ -141,8 +133,8 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
|
||||
print("Failed text:", response.text)
|
||||
|
||||
|
||||
def _get_existing_connector_names(run_suffix: str) -> list[str]:
|
||||
url = _api_url_builder(run_suffix, "/manage/connector")
|
||||
def _get_existing_connector_names(env_name: str) -> list[str]:
|
||||
url = _api_url_builder(env_name, "/manage/connector")
|
||||
|
||||
body = {
|
||||
"credential_json": {},
|
||||
@@ -156,10 +148,10 @@ def _get_existing_connector_names(run_suffix: str) -> list[str]:
|
||||
raise RuntimeError(response.__dict__)
|
||||
|
||||
|
||||
def create_connector(run_suffix: str, file_paths: list[str]) -> int:
|
||||
url = _api_url_builder(run_suffix, "/manage/admin/connector")
|
||||
def create_connector(env_name: str, file_paths: list[str]) -> int:
|
||||
url = _api_url_builder(env_name, "/manage/admin/connector")
|
||||
connector_name = base_connector_name = "search_eval_connector"
|
||||
existing_connector_names = _get_existing_connector_names(run_suffix)
|
||||
existing_connector_names = _get_existing_connector_names(env_name)
|
||||
|
||||
count = 1
|
||||
while connector_name in existing_connector_names:
|
||||
@@ -186,8 +178,8 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
|
||||
raise RuntimeError(response.__dict__)
|
||||
|
||||
|
||||
def create_credential(run_suffix: str) -> int:
|
||||
url = _api_url_builder(run_suffix, "/manage/credential")
|
||||
def create_credential(env_name: str) -> int:
|
||||
url = _api_url_builder(env_name, "/manage/credential")
|
||||
body = {
|
||||
"credential_json": {},
|
||||
"admin_public": True,
|
||||
@@ -201,12 +193,12 @@ def create_credential(run_suffix: str) -> int:
|
||||
|
||||
|
||||
@retry(tries=10, delay=2, backoff=2)
|
||||
def upload_file(run_suffix: str, zip_file_path: str) -> list[str]:
|
||||
def upload_file(env_name: str, zip_file_path: str) -> list[str]:
|
||||
files = [
|
||||
("files", open(zip_file_path, "rb")),
|
||||
]
|
||||
|
||||
api_path = _api_url_builder(run_suffix, "/manage/admin/connector/file/upload")
|
||||
api_path = _api_url_builder(env_name, "/manage/admin/connector/file/upload")
|
||||
try:
|
||||
response = requests.post(api_path, files=files)
|
||||
response.raise_for_status() # Raises an HTTPError for bad responses
|
||||
|
||||
@@ -67,20 +67,20 @@ def switch_to_commit(commit_sha: str) -> None:
|
||||
print("Repository updated successfully.")
|
||||
|
||||
|
||||
def get_docker_container_env_vars(suffix: str) -> dict:
|
||||
def get_docker_container_env_vars(env_name: str) -> dict:
|
||||
"""
|
||||
Retrieves environment variables from "background" and "api_server" Docker containers.
|
||||
"""
|
||||
print(f"Getting environment variables for containers with suffix: {suffix}")
|
||||
print(f"Getting environment variables for containers with env_name: {env_name}")
|
||||
|
||||
combined_env_vars = {}
|
||||
for container_type in ["background", "api_server"]:
|
||||
container_name = _run_command(
|
||||
f"docker ps -a --format '{{{{.Names}}}}' | awk '/{container_type}/ && /{suffix}/'"
|
||||
f"docker ps -a --format '{{{{.Names}}}}' | awk '/{container_type}/ && /{env_name}/'"
|
||||
)[0].strip()
|
||||
if not container_name:
|
||||
raise RuntimeError(
|
||||
f"No {container_type} container found with suffix: {suffix}"
|
||||
f"No {container_type} container found with env_name: {env_name}"
|
||||
)
|
||||
|
||||
env_vars_json = _run_command(
|
||||
@@ -95,9 +95,9 @@ def get_docker_container_env_vars(suffix: str) -> dict:
|
||||
return combined_env_vars
|
||||
|
||||
|
||||
def manage_data_directories(suffix: str, base_path: str, use_cloud_gpu: bool) -> None:
|
||||
def manage_data_directories(env_name: str, base_path: str, use_cloud_gpu: bool) -> None:
|
||||
# Use the user's home directory as the base path
|
||||
target_path = os.path.join(os.path.expanduser(base_path), suffix)
|
||||
target_path = os.path.join(os.path.expanduser(base_path), env_name)
|
||||
directories = {
|
||||
"DANSWER_POSTGRES_DATA_DIR": os.path.join(target_path, "postgres/"),
|
||||
"DANSWER_VESPA_DATA_DIR": os.path.join(target_path, "vespa/"),
|
||||
@@ -144,12 +144,12 @@ def _is_port_in_use(port: int) -> bool:
|
||||
|
||||
|
||||
def start_docker_compose(
|
||||
run_suffix: str, launch_web_ui: bool, use_cloud_gpu: bool, only_state: bool = False
|
||||
env_name: str, launch_web_ui: bool, use_cloud_gpu: bool, only_state: bool = False
|
||||
) -> None:
|
||||
print("Starting Docker Compose...")
|
||||
os.chdir(os.path.dirname(__file__))
|
||||
os.chdir("../../../../deployment/docker_compose/")
|
||||
command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack-{run_suffix} up -d"
|
||||
command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack-{env_name} up -d"
|
||||
command += " --build"
|
||||
command += " --force-recreate"
|
||||
|
||||
@@ -175,17 +175,17 @@ def start_docker_compose(
|
||||
print("Containers have been launched")
|
||||
|
||||
|
||||
def cleanup_docker(run_suffix: str) -> None:
|
||||
def cleanup_docker(env_name: str) -> None:
|
||||
print(
|
||||
f"Deleting Docker containers, volumes, and networks for project suffix: {run_suffix}"
|
||||
f"Deleting Docker containers, volumes, and networks for project env_name: {env_name}"
|
||||
)
|
||||
|
||||
stdout, _ = _run_command("docker ps -a --format '{{json .}}'")
|
||||
|
||||
containers = [json.loads(line) for line in stdout.splitlines()]
|
||||
if not run_suffix:
|
||||
run_suffix = datetime.now().strftime("-%Y")
|
||||
project_name = f"danswer-stack{run_suffix}"
|
||||
if not env_name:
|
||||
env_name = datetime.now().strftime("-%Y")
|
||||
project_name = f"danswer-stack{env_name}"
|
||||
containers_to_delete = [
|
||||
c for c in containers if c["Names"].startswith(project_name)
|
||||
]
|
||||
@@ -221,23 +221,23 @@ def cleanup_docker(run_suffix: str) -> None:
|
||||
|
||||
networks = stdout.splitlines()
|
||||
|
||||
networks_to_delete = [n for n in networks if run_suffix in n]
|
||||
networks_to_delete = [n for n in networks if env_name in n]
|
||||
|
||||
if not networks_to_delete:
|
||||
print(f"No networks found containing suffix: {run_suffix}")
|
||||
print(f"No networks found containing env_name: {env_name}")
|
||||
else:
|
||||
network_names = " ".join(networks_to_delete)
|
||||
_run_command(f"docker network rm {network_names}")
|
||||
|
||||
print(
|
||||
f"Successfully deleted {len(networks_to_delete)} networks containing suffix: {run_suffix}"
|
||||
f"Successfully deleted {len(networks_to_delete)} networks containing env_name: {env_name}"
|
||||
)
|
||||
|
||||
|
||||
@retry(tries=5, delay=5, backoff=2)
|
||||
def get_api_server_host_port(suffix: str) -> str:
|
||||
def get_api_server_host_port(env_name: str) -> str:
|
||||
"""
|
||||
This pulls all containers with the provided suffix
|
||||
This pulls all containers with the provided env_name
|
||||
It then grabs the JSON for the specific container with a name containing "api_server"
|
||||
It then grabs the port info from the JSON and strips out the relevant data
|
||||
"""
|
||||
@@ -248,16 +248,16 @@ def get_api_server_host_port(suffix: str) -> str:
|
||||
server_jsons = []
|
||||
|
||||
for container in containers:
|
||||
if container_name in container["Names"] and suffix in container["Names"]:
|
||||
if container_name in container["Names"] and env_name in container["Names"]:
|
||||
server_jsons.append(container)
|
||||
|
||||
if not server_jsons:
|
||||
raise RuntimeError(
|
||||
f"No container found containing: {container_name} and {suffix}"
|
||||
f"No container found containing: {container_name} and {env_name}"
|
||||
)
|
||||
elif len(server_jsons) > 1:
|
||||
raise RuntimeError(
|
||||
f"Too many containers matching {container_name} found, please indicate a suffix"
|
||||
f"Too many containers matching {container_name} found, please indicate a env_name"
|
||||
)
|
||||
server_json = server_jsons[0]
|
||||
|
||||
@@ -278,67 +278,37 @@ def get_api_server_host_port(suffix: str) -> str:
|
||||
raise RuntimeError(f"Too many ports matching {client_port} found")
|
||||
if not matching_ports:
|
||||
raise RuntimeError(
|
||||
f"No port found containing: {client_port} for container: {container_name} and suffix: {suffix}"
|
||||
f"No port found containing: {client_port} for container: {container_name} and env_name: {env_name}"
|
||||
)
|
||||
return matching_ports[0]
|
||||
|
||||
|
||||
# Added function to check Vespa container health status
|
||||
def is_vespa_container_healthy(suffix: str) -> bool:
|
||||
print(f"Checking health status of Vespa container for suffix: {suffix}")
|
||||
|
||||
# Find the Vespa container
|
||||
stdout, _ = _run_command(
|
||||
f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
|
||||
)
|
||||
container_name = stdout.strip()
|
||||
|
||||
if not container_name:
|
||||
print(f"No Vespa container found with suffix: {suffix}")
|
||||
return False
|
||||
|
||||
# Get the health status
|
||||
stdout, _ = _run_command(
|
||||
f"docker inspect --format='{{{{.State.Health.Status}}}}' {container_name}"
|
||||
)
|
||||
health_status = stdout.strip()
|
||||
|
||||
is_healthy = health_status.lower() == "healthy"
|
||||
print(f"Vespa container '{container_name}' health status: {health_status}")
|
||||
|
||||
return is_healthy
|
||||
|
||||
|
||||
# Added function to restart Vespa container
|
||||
def restart_vespa_container(suffix: str) -> None:
|
||||
print(f"Restarting Vespa container for suffix: {suffix}")
|
||||
def restart_vespa_container(env_name: str) -> None:
|
||||
print(f"Restarting Vespa container for env_name: {env_name}")
|
||||
|
||||
# Find the Vespa container
|
||||
stdout, _ = _run_command(
|
||||
f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
|
||||
f"docker ps -a --format '{{{{.Names}}}}' | awk '/index-1/ && /{env_name}/'"
|
||||
)
|
||||
container_name = stdout.strip()
|
||||
|
||||
if not container_name:
|
||||
raise RuntimeError(f"No Vespa container found with suffix: {suffix}")
|
||||
raise RuntimeError(f"No Vespa container found with env_name: {env_name}")
|
||||
|
||||
# Restart the container
|
||||
_run_command(f"docker restart {container_name}")
|
||||
|
||||
print(f"Vespa container '{container_name}' has begun restarting")
|
||||
|
||||
time_to_wait = 5
|
||||
while not is_vespa_container_healthy(suffix):
|
||||
print(f"Waiting {time_to_wait} seconds for vespa container to restart")
|
||||
time.sleep(5)
|
||||
|
||||
time.sleep(30)
|
||||
print(f"Vespa container '{container_name}' has been restarted")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Running this just cleans up the docker environment for the container indicated by existing_test_suffix
|
||||
If no existing_test_suffix is indicated, will just clean up all danswer docker containers/volumes/networks
|
||||
Running this just cleans up the docker environment for the container indicated by environment_name
|
||||
If no environment_name is indicated, will just clean up all danswer docker containers/volumes/networks
|
||||
Note: vespa/postgres mounts are not deleted
|
||||
"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -348,4 +318,4 @@ if __name__ == "__main__":
|
||||
|
||||
if not isinstance(config, dict):
|
||||
raise TypeError("config must be a dictionary")
|
||||
cleanup_docker(config["existing_test_suffix"])
|
||||
cleanup_docker(config["environment_name"])
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import yaml
|
||||
|
||||
from tests.regression.answer_quality.api_utils import check_indexing_status
|
||||
from tests.regression.answer_quality.api_utils import create_cc_pair
|
||||
from tests.regression.answer_quality.api_utils import create_connector
|
||||
from tests.regression.answer_quality.api_utils import create_credential
|
||||
@@ -10,15 +15,65 @@ from tests.regression.answer_quality.api_utils import run_cc_once
|
||||
from tests.regression.answer_quality.api_utils import upload_file
|
||||
|
||||
|
||||
def upload_test_files(zip_file_path: str, run_suffix: str) -> None:
|
||||
def unzip_and_get_file_paths(zip_file_path: str) -> list[str]:
|
||||
persistent_dir = tempfile.mkdtemp()
|
||||
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
|
||||
zip_ref.extractall(persistent_dir)
|
||||
|
||||
return [str(path) for path in Path(persistent_dir).rglob("*") if path.is_file()]
|
||||
|
||||
|
||||
def create_temp_zip_from_files(file_paths: list[str]) -> str:
|
||||
persistent_dir = tempfile.mkdtemp()
|
||||
zip_file_path = os.path.join(persistent_dir, "temp.zip")
|
||||
|
||||
with zipfile.ZipFile(zip_file_path, "w") as zip_file:
|
||||
for file_path in file_paths:
|
||||
zip_file.write(file_path, Path(file_path).name)
|
||||
|
||||
return zip_file_path
|
||||
|
||||
|
||||
def upload_test_files(zip_file_path: str, env_name: str) -> None:
|
||||
print("zip:", zip_file_path)
|
||||
file_paths = upload_file(run_suffix, zip_file_path)
|
||||
file_paths = upload_file(env_name, zip_file_path)
|
||||
|
||||
conn_id = create_connector(run_suffix, file_paths)
|
||||
cred_id = create_credential(run_suffix)
|
||||
conn_id = create_connector(env_name, file_paths)
|
||||
cred_id = create_credential(env_name)
|
||||
|
||||
create_cc_pair(run_suffix, conn_id, cred_id)
|
||||
run_cc_once(run_suffix, conn_id, cred_id)
|
||||
create_cc_pair(env_name, conn_id, cred_id)
|
||||
run_cc_once(env_name, conn_id, cred_id)
|
||||
|
||||
|
||||
def manage_file_upload(zip_file_path: str, env_name: str) -> None:
|
||||
unzipped_file_paths = unzip_and_get_file_paths(zip_file_path)
|
||||
total_file_count = len(unzipped_file_paths)
|
||||
|
||||
while True:
|
||||
doc_count, ongoing_index_attempts = check_indexing_status(env_name)
|
||||
|
||||
if ongoing_index_attempts:
|
||||
print(
|
||||
f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
|
||||
)
|
||||
elif not doc_count:
|
||||
print("No docs indexed, waiting for indexing to start")
|
||||
upload_test_files(zip_file_path, env_name)
|
||||
elif doc_count < total_file_count:
|
||||
print(f"No ongooing indexing attempts but only {doc_count} docs indexed")
|
||||
remaining_files = unzipped_file_paths[doc_count:]
|
||||
print(f"Grabbed last {len(remaining_files)} docs to try agian")
|
||||
temp_zip_file_path = create_temp_zip_from_files(remaining_files)
|
||||
upload_test_files(temp_zip_file_path, env_name)
|
||||
os.unlink(temp_zip_file_path)
|
||||
else:
|
||||
print(f"Successfully uploaded {doc_count} docs!")
|
||||
break
|
||||
|
||||
time.sleep(10)
|
||||
|
||||
for file in unzipped_file_paths:
|
||||
os.unlink(file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -27,5 +82,5 @@ if __name__ == "__main__":
|
||||
with open(config_path, "r") as file:
|
||||
config = SimpleNamespace(**yaml.safe_load(file))
|
||||
file_location = config.zipped_documents_file
|
||||
run_suffix = config.existing_test_suffix
|
||||
upload_test_files(file_location, run_suffix)
|
||||
env_name = config.environment_name
|
||||
manage_file_upload(file_location, env_name)
|
||||
|
||||
@@ -1,16 +1,12 @@
import os
from datetime import datetime
from types import SimpleNamespace

import yaml

from tests.regression.answer_quality.cli_utils import cleanup_docker
from tests.regression.answer_quality.cli_utils import manage_data_directories
from tests.regression.answer_quality.cli_utils import set_env_variables
from tests.regression.answer_quality.cli_utils import start_docker_compose
from tests.regression.answer_quality.cli_utils import switch_to_commit
from tests.regression.answer_quality.file_uploader import upload_test_files
from tests.regression.answer_quality.run_qa import run_qa_test_and_save_results


def load_config(config_filename: str) -> SimpleNamespace:
@@ -22,12 +18,16 @@ def load_config(config_filename: str) -> SimpleNamespace:

def main() -> None:
config = load_config("search_test_config.yaml")
if config.existing_test_suffix:
run_suffix = config.existing_test_suffix
print("launching danswer with existing data suffix:", run_suffix)
if config.environment_name:
env_name = config.environment_name
print("launching danswer with environment name:", env_name)
else:
run_suffix = datetime.now().strftime("-%Y%m%d-%H%M%S")
print("run_suffix:", run_suffix)
print("No env name defined. Not launching docker.")
print(
"Please define a name in the config yaml to start a new env "
"or use an existing env"
)
return

set_env_variables(
config.model_server_ip,
@@ -35,22 +35,14 @@ def main() -> None:
config.use_cloud_gpu,
config.llm,
)
manage_data_directories(run_suffix, config.output_folder, config.use_cloud_gpu)
manage_data_directories(env_name, config.output_folder, config.use_cloud_gpu)
if config.commit_sha:
switch_to_commit(config.commit_sha)

start_docker_compose(
run_suffix, config.launch_web_ui, config.use_cloud_gpu, config.only_state
env_name, config.launch_web_ui, config.use_cloud_gpu, config.only_state
)

if not config.existing_test_suffix and not config.only_state:
upload_test_files(config.zipped_documents_file, run_suffix)

run_qa_test_and_save_results(run_suffix)

if config.clean_up_docker_containers:
cleanup_docker(run_suffix)


if __name__ == "__main__":
main()
@@ -6,7 +6,6 @@ import time

import yaml

from tests.regression.answer_quality.api_utils import check_if_query_ready
from tests.regression.answer_quality.api_utils import get_answer_from_query
from tests.regression.answer_quality.cli_utils import get_current_commit_sha
from tests.regression.answer_quality.cli_utils import get_docker_container_env_vars
@@ -44,12 +43,12 @@ def _read_questions_jsonl(questions_file_path: str) -> list[dict]:

def _get_test_output_folder(config: dict) -> str:
base_output_folder = os.path.expanduser(config["output_folder"])
if config["run_suffix"]:
if config["env_name"]:
base_output_folder = os.path.join(
base_output_folder, config["run_suffix"], "evaluations_output"
base_output_folder, config["env_name"], "evaluations_output"
)
else:
base_output_folder = os.path.join(base_output_folder, "no_defined_suffix")
base_output_folder = os.path.join(base_output_folder, "no_defined_env_name")

counter = 1
output_folder_path = os.path.join(base_output_folder, "run_1")
@@ -73,12 +72,12 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:

metadata = {
"commit_sha": get_current_commit_sha(),
"run_suffix": config["run_suffix"],
"env_name": config["env_name"],
"test_config": config,
"number_of_questions_in_dataset": len(questions),
}

env_vars = get_docker_container_env_vars(config["run_suffix"])
env_vars = get_docker_container_env_vars(config["env_name"])
if env_vars["ENV_SEED_CONFIGURATION"]:
del env_vars["ENV_SEED_CONFIGURATION"]
if env_vars["GPG_KEY"]:
@@ -118,7 +117,8 @@ def _process_question(question_data: dict, config: dict, question_number: int) -
print(f"query: {query}")
context_data_list, answer = get_answer_from_query(
query=query,
run_suffix=config["run_suffix"],
only_retrieve_docs=config["only_retrieve_docs"],
env_name=config["env_name"],
)

if not context_data_list:
@@ -142,27 +142,23 @@ def _process_and_write_query_results(config: dict) -> None:
test_output_folder, questions = _initialize_files(config)
print("saving test results to folder:", test_output_folder)

while not check_if_query_ready(config["run_suffix"]):
time.sleep(5)

if config["limit"] is not None:
questions = questions[: config["limit"]]

with multiprocessing.Pool(processes=multiprocessing.cpu_count() * 2) as pool:
# Use multiprocessing to process questions
with multiprocessing.Pool() as pool:
results = pool.starmap(
_process_question, [(q, config, i + 1) for i, q in enumerate(questions)]
_process_question,
[(question, config, i + 1) for i, question in enumerate(questions)],
)

_populate_results_file(test_output_folder, results)

invalid_answer_count = 0
for result in results:
if not result.get("answer"):
if len(result["context_data_list"]) == 0:
invalid_answer_count += 1

if not result.get("context_data_list"):
raise RuntimeError("Search failed, this is a critical failure!")

_update_metadata_file(test_output_folder, invalid_answer_count)

if invalid_answer_count:
@@ -177,7 +173,7 @@ def _process_and_write_query_results(config: dict) -> None:
print("saved test results to folder:", test_output_folder)


def run_qa_test_and_save_results(run_suffix: str = "") -> None:
def run_qa_test_and_save_results(env_name: str = "") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, "search_test_config.yaml")
with open(config_path, "r") as file:
@@ -186,16 +182,16 @@ def run_qa_test_and_save_results(run_suffix: str = "") -> None:
if not isinstance(config, dict):
raise TypeError("config must be a dictionary")

if not run_suffix:
run_suffix = config["existing_test_suffix"]
if not env_name:
env_name = config["environment_name"]

config["run_suffix"] = run_suffix
config["env_name"] = env_name
_process_and_write_query_results(config)


if __name__ == "__main__":
"""
To run a different set of questions, update the questions_file in search_test_config.yaml
If there is more than one instance of Danswer running, specify the suffix in search_test_config.yaml
If there is more than one instance of Danswer running, specify the env_name in search_test_config.yaml
"""
run_qa_test_and_save_results()

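As a quick usage sketch, the renamed entry point can be driven either from the config or against an explicitly named environment (the name below is illustrative, not taken from this change):

# Config-driven: falls back to environment_name from search_test_config.yaml
run_qa_test_and_save_results()

# Explicit: target a specific running environment
run_qa_test_and_save_results(env_name="my_test_env")
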
@@ -13,14 +13,11 @@ questions_file: "~/sample_questions.yaml"
# Git commit SHA to use (null means use current code as is)
commit_sha: null

# Whether to remove Docker containers after the test
clean_up_docker_containers: true

# Whether to launch a web UI for the test
launch_web_ui: false

# Whether to only run Vespa and Postgres
only_state: false
# Only retrieve documents, not LLM response
only_retrieve_docs: false

# Whether to use a cloud GPU for processing
use_cloud_gpu: false
@@ -31,9 +28,8 @@ model_server_ip: "PUT_PUBLIC_CLOUD_IP_HERE"
# Port of the model server (placeholder)
model_server_port: "PUT_PUBLIC_CLOUD_PORT_HERE"

# Suffix for existing test results (E.g. -1234-5678)
# empty string means no suffix
existing_test_suffix: ""
# Name for existing testing env (empty string uses default ports)
environment_name: ""

# Limit on number of tests to run (null means no limit)
limit: null

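A minimal sketch of how the runner consumes these keys, following the load_config flow shown earlier (illustrative only, not the exact main() body):

config = load_config("search_test_config.yaml")  # SimpleNamespace over the YAML above
if config.environment_name:
    run_qa_test_and_save_results(config.environment_name)
else:
    print("No env name defined. Not launching docker.")
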
@@ -36,8 +36,7 @@ export function ModelSelectionConfirmationModal({
|
||||
at least 16GB of RAM to Danswer during this process.
|
||||
</Text>
|
||||
|
||||
{/* TODO Change this back- ensure functional */}
|
||||
{!isCustom && (
|
||||
{isCustom && (
|
||||
<Callout title="IMPORTANT" color="yellow" className="mt-4">
|
||||
We've detected that this is a custom-specified embedding
|
||||
model. Since we have to download the model files before verifying
|
||||
|
||||
@@ -149,6 +149,7 @@ function Main() {
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setShowTentativeOpenProvider(null);
|
||||
setShowTentativeModel(null);
|
||||
mutate("/api/secondary-index/get-secondary-embedding-model");
|
||||
if (!connectors || !connectors.length) {
|
||||
@@ -274,6 +275,7 @@ function Main() {
|
||||
onClose={() => setAlreadySelectedModel(null)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{showTentativeOpenProvider && (
|
||||
<ModelSelectionConfirmationModal
|
||||
selectedModel={showTentativeOpenProvider}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import { Button, Divider, Text } from "@tremor/react";
|
||||
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||
import {
|
||||
ArrayHelpers,
|
||||
ErrorMessage,
|
||||
@@ -15,11 +16,16 @@ import {
|
||||
SubLabel,
|
||||
TextArrayField,
|
||||
TextFormField,
|
||||
BooleanFormField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { useState } from "react";
|
||||
import { Bubble } from "@/components/Bubble";
|
||||
import { GroupsIcon } from "@/components/icons/icons";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useUserGroups } from "@/lib/hooks";
|
||||
import { FullLLMProvider } from "./interfaces";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import * as Yup from "yup";
|
||||
import isEqual from "lodash/isEqual";
|
||||
|
||||
@@ -44,9 +50,16 @@ export function CustomLLMProviderUpdateForm({
|
||||
}) {
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
|
||||
// EE only
|
||||
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
|
||||
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const [testError, setTestError] = useState<string>("");
|
||||
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
|
||||
// Define the initial values based on the provider's requirements
|
||||
const initialValues = {
|
||||
name: existingLlmProvider?.name ?? "",
|
||||
@@ -61,6 +74,8 @@ export function CustomLLMProviderUpdateForm({
|
||||
custom_config_list: existingLlmProvider?.custom_config
|
||||
? Object.entries(existingLlmProvider.custom_config)
|
||||
: [],
|
||||
is_public: existingLlmProvider?.is_public ?? true,
|
||||
groups: existingLlmProvider?.groups ?? [],
|
||||
};
|
||||
|
||||
// Setup validation schema if required
|
||||
@@ -74,6 +89,9 @@ export function CustomLLMProviderUpdateForm({
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
fast_default_model_name: Yup.string().nullable(),
|
||||
custom_config_list: Yup.array(),
|
||||
// EE Only
|
||||
is_public: Yup.boolean().required(),
|
||||
groups: Yup.array().of(Yup.number()),
|
||||
});
|
||||
|
||||
return (
|
||||
@@ -97,6 +115,9 @@ export function CustomLLMProviderUpdateForm({
|
||||
return;
|
||||
}
|
||||
|
||||
// don't set groups if marked as public
|
||||
const groups = values.is_public ? [] : values.groups;
|
||||
|
||||
// test the configuration
|
||||
if (!isEqual(values, initialValues)) {
|
||||
setIsTesting(true);
|
||||
@@ -188,93 +209,97 @@ export function CustomLLMProviderUpdateForm({
|
||||
setSubmitting(false);
|
||||
}}
|
||||
>
|
||||
{({ values }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="name"
|
||||
label="Display Name"
|
||||
subtext="A name which you can use to identify this provider when selecting it in the UI."
|
||||
placeholder="Display Name"
|
||||
/>
|
||||
{({ values, setFieldValue }) => {
|
||||
return (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="name"
|
||||
label="Display Name"
|
||||
subtext="A name which you can use to identify this provider when selecting it in the UI."
|
||||
placeholder="Display Name"
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
<Divider />
|
||||
|
||||
<TextFormField
|
||||
name="provider"
|
||||
label="Provider Name"
|
||||
subtext={
|
||||
<TextFormField
|
||||
name="provider"
|
||||
label="Provider Name"
|
||||
subtext={
|
||||
<>
|
||||
Should be one of the providers listed at{" "}
|
||||
<a
|
||||
target="_blank"
|
||||
href="https://docs.litellm.ai/docs/providers"
|
||||
className="text-link"
|
||||
>
|
||||
https://docs.litellm.ai/docs/providers
|
||||
</a>
|
||||
.
|
||||
</>
|
||||
}
|
||||
placeholder="Name of the custom provider"
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
|
||||
<SubLabel>
|
||||
Fill in the following as is needed. Refer to the LiteLLM
|
||||
documentation for the model provider name specified above in order
|
||||
to determine which fields are required.
|
||||
</SubLabel>
|
||||
|
||||
<TextFormField
|
||||
name="api_key"
|
||||
label="[Optional] API Key"
|
||||
placeholder="API Key"
|
||||
type="password"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="api_base"
|
||||
label="[Optional] API Base"
|
||||
placeholder="API Base"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="api_version"
|
||||
label="[Optional] API Version"
|
||||
placeholder="API Version"
|
||||
/>
|
||||
|
||||
<Label>[Optional] Custom Configs</Label>
|
||||
<SubLabel>
|
||||
<>
|
||||
Should be one of the providers listed at{" "}
|
||||
<a
|
||||
target="_blank"
|
||||
href="https://docs.litellm.ai/docs/providers"
|
||||
className="text-link"
|
||||
>
|
||||
https://docs.litellm.ai/docs/providers
|
||||
</a>
|
||||
.
|
||||
<div>
|
||||
Additional configurations needed by the model provider. Are
|
||||
passed to litellm via environment variables.
|
||||
</div>
|
||||
|
||||
<div className="mt-2">
|
||||
For example, when configuring the Cloudflare provider, you
|
||||
would need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
|
||||
Cloudflare account ID as the value.
|
||||
</div>
|
||||
</>
|
||||
}
|
||||
placeholder="Name of the custom provider"
|
||||
/>
|
||||
</SubLabel>
|
||||
|
||||
<Divider />
|
||||
|
||||
<SubLabel>
|
||||
Fill in the following as is needed. Refer to the LiteLLM
|
||||
documentation for the model provider name specified above in order
|
||||
to determine which fields are required.
|
||||
</SubLabel>
|
||||
|
||||
<TextFormField
|
||||
name="api_key"
|
||||
label="[Optional] API Key"
|
||||
placeholder="API Key"
|
||||
type="password"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="api_base"
|
||||
label="[Optional] API Base"
|
||||
placeholder="API Base"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="api_version"
|
||||
label="[Optional] API Version"
|
||||
placeholder="API Version"
|
||||
/>
|
||||
|
||||
<Label>[Optional] Custom Configs</Label>
|
||||
<SubLabel>
|
||||
<>
|
||||
<div>
|
||||
Additional configurations needed by the model provider. Are
|
||||
passed to litellm via environment variables.
|
||||
</div>
|
||||
|
||||
<div className="mt-2">
|
||||
For example, when configuring the Cloudflare provider, you would
|
||||
need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
|
||||
Cloudflare account ID as the value.
|
||||
</div>
|
||||
</>
|
||||
</SubLabel>
|
||||
|
||||
<FieldArray
|
||||
name="custom_config_list"
|
||||
render={(arrayHelpers: ArrayHelpers<any[]>) => (
|
||||
<div>
|
||||
{values.custom_config_list.map((_, index) => {
|
||||
return (
|
||||
<div key={index} className={index === 0 ? "mt-2" : "mt-6"}>
|
||||
<div className="flex">
|
||||
<div className="w-full mr-6 border border-border p-3 rounded">
|
||||
<div>
|
||||
<Label>Key</Label>
|
||||
<Field
|
||||
name={`custom_config_list[${index}][0]`}
|
||||
className={`
|
||||
<FieldArray
|
||||
name="custom_config_list"
|
||||
render={(arrayHelpers: ArrayHelpers<any[]>) => (
|
||||
<div>
|
||||
{values.custom_config_list.map((_, index) => {
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
className={index === 0 ? "mt-2" : "mt-6"}
|
||||
>
|
||||
<div className="flex">
|
||||
<div className="w-full mr-6 border border-border p-3 rounded">
|
||||
<div>
|
||||
<Label>Key</Label>
|
||||
<Field
|
||||
name={`custom_config_list[${index}][0]`}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
@@ -284,20 +309,20 @@ export function CustomLLMProviderUpdateForm({
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`custom_config_list[${index}][0]`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`custom_config_list[${index}][0]`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="mt-3">
|
||||
<Label>Value</Label>
|
||||
<Field
|
||||
name={`custom_config_list[${index}][1]`}
|
||||
className={`
|
||||
<div className="mt-3">
|
||||
<Label>Value</Label>
|
||||
<Field
|
||||
name={`custom_config_list[${index}][1]`}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
@@ -307,121 +332,190 @@ export function CustomLLMProviderUpdateForm({
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`custom_config_list[${index}][1]`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`custom_config_list[${index}][1]`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="my-auto">
|
||||
<FiX
|
||||
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
|
||||
onClick={() => arrayHelpers.remove(index)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="my-auto">
|
||||
<FiX
|
||||
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
|
||||
onClick={() => arrayHelpers.remove(index)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
);
|
||||
})}
|
||||
|
||||
<Button
|
||||
onClick={() => {
|
||||
arrayHelpers.push(["", ""]);
|
||||
}}
|
||||
className="mt-3"
|
||||
color="green"
|
||||
size="xs"
|
||||
type="button"
|
||||
icon={FiPlus}
|
||||
>
|
||||
Add New
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
<Button
|
||||
onClick={() => {
|
||||
arrayHelpers.push(["", ""]);
|
||||
}}
|
||||
className="mt-3"
|
||||
color="green"
|
||||
size="xs"
|
||||
type="button"
|
||||
icon={FiPlus}
|
||||
>
|
||||
Add New
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
<Divider />
|
||||
|
||||
<TextArrayField
|
||||
name="model_names"
|
||||
label="Model Names"
|
||||
values={values}
|
||||
subtext={`List the individual models that you want to make
|
||||
<TextArrayField
|
||||
name="model_names"
|
||||
label="Model Names"
|
||||
values={values}
|
||||
subtext={`List the individual models that you want to make
|
||||
available as a part of this provider. At least one must be specified.
|
||||
As an example, for OpenAI one model might be "gpt-4".`}
|
||||
/>
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
<Divider />
|
||||
|
||||
<TextFormField
|
||||
name="default_model_name"
|
||||
subtext={`
|
||||
<TextFormField
|
||||
name="default_model_name"
|
||||
subtext={`
|
||||
The model to use by default for this provider unless
|
||||
otherwise specified. Must be one of the models listed
|
||||
above.`}
|
||||
label="Default Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
label="Default Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
<TextFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If not set, will use
|
||||
the Default Model configured above.`}
|
||||
label="[Optional] Fast Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
label="[Optional] Fast Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
<Divider />
|
||||
|
||||
<div>
|
||||
{/* NOTE: this is above the test button to make sure it's visible */}
|
||||
{testError && <Text className="text-error mt-2">{testError}</Text>}
|
||||
<AdvancedOptionsToggle
|
||||
showAdvancedOptions={showAdvancedOptions}
|
||||
setShowAdvancedOptions={setShowAdvancedOptions}
|
||||
/>
|
||||
|
||||
<div className="flex w-full mt-4">
|
||||
<Button type="submit" size="xs">
|
||||
{isTesting ? (
|
||||
<LoadingAnimation text="Testing" />
|
||||
) : existingLlmProvider ? (
|
||||
"Update"
|
||||
) : (
|
||||
"Enable"
|
||||
{showAdvancedOptions && (
|
||||
<>
|
||||
{isPaidEnterpriseFeaturesEnabled && userGroups && (
|
||||
<>
|
||||
<BooleanFormField
|
||||
small
|
||||
noPadding
|
||||
alignTop
|
||||
name="is_public"
|
||||
label="Is Public?"
|
||||
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
|
||||
/>
|
||||
|
||||
{userGroups &&
|
||||
userGroups.length > 0 &&
|
||||
!values.is_public && (
|
||||
<div>
|
||||
<Text>
|
||||
Select which User Groups should have access to this
|
||||
LLM Provider.
|
||||
</Text>
|
||||
<div className="flex flex-wrap gap-2 mt-2">
|
||||
{userGroups.map((userGroup) => {
|
||||
const isSelected = values.groups.includes(
|
||||
userGroup.id
|
||||
);
|
||||
return (
|
||||
<Bubble
|
||||
key={userGroup.id}
|
||||
isSelected={isSelected}
|
||||
onClick={() => {
|
||||
if (isSelected) {
|
||||
setFieldValue(
|
||||
"groups",
|
||||
values.groups.filter(
|
||||
(id) => id !== userGroup.id
|
||||
)
|
||||
);
|
||||
} else {
|
||||
setFieldValue("groups", [
|
||||
...values.groups,
|
||||
userGroup.id,
|
||||
]);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex">
|
||||
<GroupsIcon />
|
||||
<div className="ml-1">{userGroup.name}</div>
|
||||
</div>
|
||||
</Bubble>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
{existingLlmProvider && (
|
||||
<Button
|
||||
type="button"
|
||||
color="red"
|
||||
className="ml-3"
|
||||
size="xs"
|
||||
icon={FiTrash}
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
alert(`Failed to delete provider: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
</>
|
||||
)}
|
||||
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
onClose();
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
<div>
|
||||
{/* NOTE: this is above the test button to make sure it's visible */}
|
||||
{testError && (
|
||||
<Text className="text-error mt-2">{testError}</Text>
|
||||
)}
|
||||
|
||||
<div className="flex w-full mt-4">
|
||||
<Button type="submit" size="xs">
|
||||
{isTesting ? (
|
||||
<LoadingAnimation text="Testing" />
|
||||
) : existingLlmProvider ? (
|
||||
"Update"
|
||||
) : (
|
||||
"Enable"
|
||||
)}
|
||||
</Button>
|
||||
{existingLlmProvider && (
|
||||
<Button
|
||||
type="button"
|
||||
color="red"
|
||||
className="ml-3"
|
||||
size="xs"
|
||||
icon={FiTrash}
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
alert(`Failed to delete provider: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
onClose();
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Form>
|
||||
);
|
||||
}}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||
import { Button, Divider, Text } from "@tremor/react";
|
||||
import { Form, Formik } from "formik";
|
||||
import { FiTrash } from "react-icons/fi";
|
||||
@@ -6,11 +7,16 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
|
||||
import {
|
||||
SelectorFormField,
|
||||
TextFormField,
|
||||
BooleanFormField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { useState } from "react";
|
||||
import { Bubble } from "@/components/Bubble";
|
||||
import { GroupsIcon } from "@/components/icons/icons";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useUserGroups } from "@/lib/hooks";
|
||||
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import * as Yup from "yup";
|
||||
import isEqual from "lodash/isEqual";
|
||||
|
||||
@@ -29,9 +35,16 @@ export function LLMProviderUpdateForm({
|
||||
}) {
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
|
||||
// EE only
|
||||
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
|
||||
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const [testError, setTestError] = useState<string>("");
|
||||
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
|
||||
// Define the initial values based on the provider's requirements
|
||||
const initialValues = {
|
||||
name: existingLlmProvider?.name ?? "",
|
||||
@@ -54,6 +67,8 @@ export function LLMProviderUpdateForm({
|
||||
},
|
||||
{} as { [key: string]: string }
|
||||
),
|
||||
is_public: existingLlmProvider?.is_public ?? true,
|
||||
groups: existingLlmProvider?.groups ?? [],
|
||||
};
|
||||
|
||||
const [validatedConfig, setValidatedConfig] = useState(
|
||||
@@ -91,6 +106,9 @@ export function LLMProviderUpdateForm({
|
||||
: {}),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
fast_default_model_name: Yup.string().nullable(),
|
||||
// EE Only
|
||||
is_public: Yup.boolean().required(),
|
||||
groups: Yup.array().of(Yup.number()),
|
||||
});
|
||||
|
||||
return (
|
||||
@@ -193,7 +211,7 @@ export function LLMProviderUpdateForm({
|
||||
setSubmitting(false);
|
||||
}}
|
||||
>
|
||||
{({ values }) => (
|
||||
{({ values, setFieldValue }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="name"
|
||||
@@ -293,6 +311,69 @@ export function LLMProviderUpdateForm({
|
||||
|
||||
<Divider />
|
||||
|
||||
<AdvancedOptionsToggle
|
||||
showAdvancedOptions={showAdvancedOptions}
|
||||
setShowAdvancedOptions={setShowAdvancedOptions}
|
||||
/>
|
||||
|
||||
{showAdvancedOptions && (
|
||||
<>
|
||||
{isPaidEnterpriseFeaturesEnabled && userGroups && (
|
||||
<>
|
||||
<BooleanFormField
|
||||
small
|
||||
noPadding
|
||||
alignTop
|
||||
name="is_public"
|
||||
label="Is Public?"
|
||||
subtext="If set, this LLM Provider will be available to all users. If not, only the specified User Groups will be able to use it."
|
||||
/>
|
||||
|
||||
{userGroups && userGroups.length > 0 && !values.is_public && (
|
||||
<div>
|
||||
<Text>
|
||||
Select which User Groups should have access to this LLM
|
||||
Provider.
|
||||
</Text>
|
||||
<div className="flex flex-wrap gap-2 mt-2">
|
||||
{userGroups.map((userGroup) => {
|
||||
const isSelected = values.groups.includes(
|
||||
userGroup.id
|
||||
);
|
||||
return (
|
||||
<Bubble
|
||||
key={userGroup.id}
|
||||
isSelected={isSelected}
|
||||
onClick={() => {
|
||||
if (isSelected) {
|
||||
setFieldValue(
|
||||
"groups",
|
||||
values.groups.filter(
|
||||
(id) => id !== userGroup.id
|
||||
)
|
||||
);
|
||||
} else {
|
||||
setFieldValue("groups", [
|
||||
...values.groups,
|
||||
userGroup.id,
|
||||
]);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex">
|
||||
<GroupsIcon />
|
||||
<div className="ml-1">{userGroup.name}</div>
|
||||
</div>
|
||||
</Bubble>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<div>
|
||||
{/* NOTE: this is above the test button to make sure it's visible */}
|
||||
{testError && <Text className="text-error mt-2">{testError}</Text>}
|
||||
|
||||
@@ -17,6 +17,8 @@ export interface WellKnownLLMProviderDescriptor {
|
||||
llm_names: string[];
|
||||
default_model: string | null;
|
||||
default_fast_model: string | null;
|
||||
is_public: boolean;
|
||||
groups: number[];
|
||||
}
|
||||
|
||||
export interface LLMProvider {
|
||||
@@ -28,6 +30,8 @@ export interface LLMProvider {
|
||||
custom_config: { [key: string]: string } | null;
|
||||
default_model_name: string;
|
||||
fast_default_model_name: string | null;
|
||||
is_public: boolean;
|
||||
groups: number[];
|
||||
}
|
||||
|
||||
export interface FullLLMProvider extends LLMProvider {
|
||||
@@ -44,4 +48,6 @@ export interface LLMProviderDescriptor {
|
||||
default_model_name: string;
|
||||
fast_default_model_name: string | null;
|
||||
is_default_provider: boolean | null;
|
||||
is_public: boolean;
|
||||
groups: number[];
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ export default async function GalleryPage({
|
||||
chatSessions,
|
||||
availableSources,
|
||||
availableDocumentSets: documentSets,
|
||||
availablePersonas: assistants,
|
||||
availableAssistants: assistants,
|
||||
availableTags: tags,
|
||||
llmProviders,
|
||||
folders,
|
||||
|
||||
@@ -52,7 +52,7 @@ export default async function GalleryPage({
|
||||
chatSessions,
|
||||
availableSources,
|
||||
availableDocumentSets: documentSets,
|
||||
availablePersonas: assistants,
|
||||
availableAssistants: assistants,
|
||||
availableTags: tags,
|
||||
llmProviders,
|
||||
folders,
|
||||
|
||||
@@ -14,15 +14,18 @@ export function ChatBanner() {
|
||||
return (
|
||||
<div
|
||||
className={`
|
||||
z-[39]
|
||||
h-[30px]
|
||||
bg-background-100
|
||||
shadow-sm
|
||||
m-2
|
||||
rounded
|
||||
border-border
|
||||
border
|
||||
flex`}
|
||||
mt-8
|
||||
mb-2
|
||||
mx-2
|
||||
z-[39]
|
||||
w-full
|
||||
h-[30px]
|
||||
bg-background-100
|
||||
shadow-sm
|
||||
rounded
|
||||
border-border
|
||||
border
|
||||
flex`}
|
||||
>
|
||||
<div className="mx-auto text-emphasis text-sm flex flex-col">
|
||||
<div className="my-auto">
|
||||
|
||||
@@ -63,7 +63,6 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import Dropzone from "react-dropzone";
|
||||
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
|
||||
import { ChatInputBar } from "./input/ChatInputBar";
|
||||
import { ConfigurationModal } from "./modal/configuration/ConfigurationModal";
|
||||
import { useChatContext } from "@/components/context/ChatContext";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { orderAssistantsForUser } from "@/lib/assistants/orderAssistants";
|
||||
@@ -82,38 +81,29 @@ const SYSTEM_MESSAGE_ID = -3;
|
||||
export function ChatPage({
|
||||
toggle,
|
||||
documentSidebarInitialWidth,
|
||||
defaultSelectedPersonaId,
|
||||
defaultSelectedAssistantId,
|
||||
toggledSidebar,
|
||||
}: {
|
||||
toggle: () => void;
|
||||
documentSidebarInitialWidth?: number;
|
||||
defaultSelectedPersonaId?: number;
|
||||
defaultSelectedAssistantId?: number;
|
||||
toggledSidebar: boolean;
|
||||
}) {
|
||||
const [configModalActiveTab, setConfigModalActiveTab] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
let {
|
||||
user,
|
||||
chatSessions,
|
||||
availableSources,
|
||||
availableDocumentSets,
|
||||
availablePersonas,
|
||||
availableAssistants,
|
||||
llmProviders,
|
||||
folders,
|
||||
openedFolders,
|
||||
} = useChatContext();
|
||||
|
||||
const filteredAssistants = orderAssistantsForUser(availablePersonas, user);
|
||||
|
||||
const [selectedAssistant, setSelectedAssistant] = useState<Persona | null>(
|
||||
null
|
||||
);
|
||||
const [alternativeGeneratingAssistant, setAlternativeGeneratingAssistant] =
|
||||
useState<Persona | null>(null);
|
||||
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
// chat session
|
||||
const existingChatIdRaw = searchParams.get("chatId");
|
||||
const existingChatSessionId = existingChatIdRaw
|
||||
? parseInt(existingChatIdRaw)
|
||||
@@ -123,9 +113,51 @@ export function ChatPage({
|
||||
);
|
||||
const chatSessionIdRef = useRef<number | null>(existingChatSessionId);
|
||||
|
||||
// LLM
|
||||
const llmOverrideManager = useLlmOverride(selectedChatSession);
|
||||
|
||||
const existingChatSessionPersonaId = selectedChatSession?.persona_id;
|
||||
// Assistants
|
||||
const filteredAssistants = orderAssistantsForUser(availableAssistants, user);
|
||||
|
||||
const existingChatSessionAssistantId = selectedChatSession?.persona_id;
|
||||
const [selectedAssistant, setSelectedAssistant] = useState<
|
||||
Persona | undefined
|
||||
>(
|
||||
// NOTE: look through available assistants here, so that even if the user
|
||||
// has hidden this assistant it still shows the correct assistant when
|
||||
// going back to an old chat session
|
||||
existingChatSessionAssistantId !== undefined
|
||||
? availableAssistants.find(
|
||||
(assistant) => assistant.id === existingChatSessionAssistantId
|
||||
)
|
||||
: defaultSelectedAssistantId !== undefined
|
||||
? availableAssistants.find(
|
||||
(assistant) => assistant.id === defaultSelectedAssistantId
|
||||
)
|
||||
: undefined
|
||||
);
|
||||
const setSelectedAssistantFromId = (assistantId: number) => {
|
||||
// NOTE: also intentionally look through available assistants here, so that
|
||||
// even if the user has hidden an assistant they can still go back to it
|
||||
// for old chats
|
||||
setSelectedAssistant(
|
||||
availableAssistants.find((assistant) => assistant.id === assistantId)
|
||||
);
|
||||
};
|
||||
const liveAssistant =
|
||||
selectedAssistant || filteredAssistants[0] || availableAssistants[0];
|
||||
|
||||
// this is for "@"ing assistants
|
||||
const [alternativeAssistant, setAlternativeAssistant] =
|
||||
useState<Persona | null>(null);
|
||||
|
||||
// this is used to track which assistant is being used to generate the current message
|
||||
// for example, this would come into play when:
|
||||
// 1. default assistant is `Danswer`
|
||||
// 2. we "@"ed the `GPT` assistant and sent a message
|
||||
// 3. while the `GPT` assistant message is generating, we "@" the `Paraphrase` assistant
|
||||
const [alternativeGeneratingAssistant, setAlternativeGeneratingAssistant] =
|
||||
useState<Persona | null>(null);
|
||||
|
||||
// used to track whether or not the initial "submit on load" has been performed
|
||||
// this only applies if `?submit-on-load=true` or `?submit-on-load=1` is in the URL
|
||||
@@ -182,14 +214,10 @@ export function ChatPage({
|
||||
async function initialSessionFetch() {
|
||||
if (existingChatSessionId === null) {
|
||||
setIsFetchingChatMessages(false);
|
||||
if (defaultSelectedPersonaId !== undefined) {
|
||||
setSelectedPersona(
|
||||
filteredAssistants.find(
|
||||
(persona) => persona.id === defaultSelectedPersonaId
|
||||
)
|
||||
);
|
||||
if (defaultSelectedAssistantId !== undefined) {
|
||||
setSelectedAssistantFromId(defaultSelectedAssistantId);
|
||||
} else {
|
||||
setSelectedPersona(undefined);
|
||||
setSelectedAssistant(undefined);
|
||||
}
|
||||
setCompleteMessageDetail({
|
||||
sessionId: null,
|
||||
@@ -214,12 +242,7 @@ export function ChatPage({
|
||||
);
|
||||
|
||||
const chatSession = (await response.json()) as BackendChatSession;
|
||||
|
||||
setSelectedPersona(
|
||||
filteredAssistants.find(
|
||||
(persona) => persona.id === chatSession.persona_id
|
||||
)
|
||||
);
|
||||
setSelectedAssistantFromId(chatSession.persona_id);
|
||||
|
||||
const newMessageMap = processRawChatHistory(chatSession.messages);
|
||||
const newMessageHistory = buildLatestMessageChain(newMessageMap);
|
||||
@@ -373,32 +396,18 @@ export function ChatPage({
|
||||
)
|
||||
: { aiMessage: null };
|
||||
|
||||
const [selectedPersona, setSelectedPersona] = useState<Persona | undefined>(
|
||||
existingChatSessionPersonaId !== undefined
|
||||
? filteredAssistants.find(
|
||||
(persona) => persona.id === existingChatSessionPersonaId
|
||||
)
|
||||
: defaultSelectedPersonaId !== undefined
|
||||
? filteredAssistants.find(
|
||||
(persona) => persona.id === defaultSelectedPersonaId
|
||||
)
|
||||
: undefined
|
||||
);
|
||||
const livePersona =
|
||||
selectedPersona || filteredAssistants[0] || availablePersonas[0];
|
||||
|
||||
const [chatSessionSharedStatus, setChatSessionSharedStatus] =
|
||||
useState<ChatSessionSharedStatus>(ChatSessionSharedStatus.Private);
|
||||
|
||||
useEffect(() => {
|
||||
if (messageHistory.length === 0 && chatSessionIdRef.current === null) {
|
||||
setSelectedPersona(
|
||||
setSelectedAssistant(
|
||||
filteredAssistants.find(
|
||||
(persona) => persona.id === defaultSelectedPersonaId
|
||||
(persona) => persona.id === defaultSelectedAssistantId
|
||||
)
|
||||
);
|
||||
}
|
||||
}, [defaultSelectedPersonaId]);
|
||||
}, [defaultSelectedAssistantId]);
|
||||
|
||||
const [
|
||||
selectedDocuments,
|
||||
@@ -414,7 +423,7 @@ export function ChatPage({
|
||||
useEffect(() => {
|
||||
async function fetchMaxTokens() {
|
||||
const response = await fetch(
|
||||
`/api/chat/max-selected-document-tokens?persona_id=${livePersona.id}`
|
||||
`/api/chat/max-selected-document-tokens?persona_id=${liveAssistant.id}`
|
||||
);
|
||||
if (response.ok) {
|
||||
const maxTokens = (await response.json()).max_tokens as number;
|
||||
@@ -423,12 +432,12 @@ export function ChatPage({
|
||||
}
|
||||
|
||||
fetchMaxTokens();
|
||||
}, [livePersona]);
|
||||
}, [liveAssistant]);
|
||||
|
||||
const filterManager = useFilters();
|
||||
const [finalAvailableSources, finalAvailableDocumentSets] =
|
||||
computeAvailableFilters({
|
||||
selectedPersona,
|
||||
selectedPersona: selectedAssistant,
|
||||
availableSources,
|
||||
availableDocumentSets,
|
||||
});
|
||||
@@ -624,16 +633,16 @@ export function ChatPage({
|
||||
queryOverride,
|
||||
forceSearch,
|
||||
isSeededChat,
|
||||
alternativeAssistant = null,
|
||||
alternativeAssistantOverride = null,
|
||||
}: {
|
||||
messageIdToResend?: number;
|
||||
messageOverride?: string;
|
||||
queryOverride?: string;
|
||||
forceSearch?: boolean;
|
||||
isSeededChat?: boolean;
|
||||
alternativeAssistant?: Persona | null;
|
||||
alternativeAssistantOverride?: Persona | null;
|
||||
} = {}) => {
|
||||
setAlternativeGeneratingAssistant(alternativeAssistant);
|
||||
setAlternativeGeneratingAssistant(alternativeAssistantOverride);
|
||||
|
||||
clientScrollToBottom();
|
||||
let currChatSessionId: number;
|
||||
@@ -643,7 +652,7 @@ export function ChatPage({
|
||||
|
||||
if (isNewSession) {
|
||||
currChatSessionId = await createChatSession(
|
||||
livePersona?.id || 0,
|
||||
liveAssistant?.id || 0,
|
||||
searchParamBasedChatSessionName
|
||||
);
|
||||
} else {
|
||||
@@ -721,9 +730,9 @@ export function ChatPage({
|
||||
parentMessage = frozenMessageMap.get(SYSTEM_MESSAGE_ID) || null;
|
||||
}
|
||||
|
||||
const currentAssistantId = alternativeAssistant
|
||||
? alternativeAssistant.id
|
||||
: selectedAssistant?.id;
|
||||
const currentAssistantId = alternativeAssistantOverride
|
||||
? alternativeAssistantOverride.id
|
||||
: alternativeAssistant?.id || liveAssistant.id;
|
||||
|
||||
resetInputBar();
|
||||
|
||||
@@ -751,7 +760,7 @@ export function ChatPage({
|
||||
fileDescriptors: currentMessageFiles,
|
||||
parentMessageId: lastSuccessfulMessageId,
|
||||
chatSessionId: currChatSessionId,
|
||||
promptId: livePersona?.prompts[0]?.id || 0,
|
||||
promptId: liveAssistant?.prompts[0]?.id || 0,
|
||||
filters: buildFilters(
|
||||
filterManager.selectedSources,
|
||||
filterManager.selectedDocumentSets,
|
||||
@@ -868,7 +877,7 @@ export function ChatPage({
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCalls: finalMessage?.tool_calls || toolCalls,
|
||||
parentMessageId: newUserMessageId,
|
||||
alternateAssistantID: selectedAssistant?.id,
|
||||
alternateAssistantID: alternativeAssistant?.id,
|
||||
},
|
||||
]);
|
||||
}
|
||||
@@ -964,19 +973,23 @@ export function ChatPage({
|
||||
}
|
||||
};
|
||||
|
||||
const onPersonaChange = (persona: Persona | null) => {
|
||||
if (persona && persona.id !== livePersona.id) {
|
||||
const onAssistantChange = (assistant: Persona | null) => {
|
||||
if (assistant && assistant.id !== liveAssistant.id) {
|
||||
// remove uploaded files
|
||||
setCurrentMessageFiles([]);
|
||||
setSelectedPersona(persona);
|
||||
setSelectedAssistant(assistant);
|
||||
textAreaRef.current?.focus();
|
||||
router.push(buildChatUrl(searchParams, null, persona.id));
|
||||
router.push(buildChatUrl(searchParams, null, assistant.id));
|
||||
}
|
||||
};
|
||||
|
||||
const handleImageUpload = (acceptedFiles: File[]) => {
|
||||
const llmAcceptsImages = checkLLMSupportsImageInput(
|
||||
...getFinalLLM(llmProviders, livePersona, llmOverrideManager.llmOverride)
|
||||
...getFinalLLM(
|
||||
llmProviders,
|
||||
liveAssistant,
|
||||
llmOverrideManager.llmOverride
|
||||
)
|
||||
);
|
||||
const imageFiles = acceptedFiles.filter((file) =>
|
||||
file.type.startsWith("image/")
|
||||
@@ -1058,23 +1071,23 @@ export function ChatPage({
|
||||
useEffect(() => {
|
||||
const includes = checkAnyAssistantHasSearch(
|
||||
messageHistory,
|
||||
availablePersonas,
|
||||
livePersona
|
||||
availableAssistants,
|
||||
liveAssistant
|
||||
);
|
||||
setRetrievalEnabled(includes);
|
||||
}, [messageHistory, availablePersonas, livePersona]);
|
||||
}, [messageHistory, availableAssistants, liveAssistant]);
|
||||
|
||||
const [retrievalEnabled, setRetrievalEnabled] = useState(() => {
|
||||
return checkAnyAssistantHasSearch(
|
||||
messageHistory,
|
||||
availablePersonas,
|
||||
livePersona
|
||||
availableAssistants,
|
||||
liveAssistant
|
||||
);
|
||||
});
|
||||
|
||||
const innerSidebarElementRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const currentPersona = selectedAssistant || livePersona;
|
||||
const currentPersona = alternativeAssistant || liveAssistant;
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
@@ -1176,21 +1189,8 @@ export function ChatPage({
|
||||
/>
|
||||
)}
|
||||
|
||||
<ConfigurationModal
|
||||
chatSessionId={chatSessionIdRef.current!}
|
||||
activeTab={configModalActiveTab}
|
||||
setActiveTab={setConfigModalActiveTab}
|
||||
onClose={() => setConfigModalActiveTab(null)}
|
||||
filterManager={filterManager}
|
||||
availableAssistants={filteredAssistants}
|
||||
selectedAssistant={livePersona}
|
||||
setSelectedAssistant={onPersonaChange}
|
||||
llmProviders={llmProviders}
|
||||
llmOverrideManager={llmOverrideManager}
|
||||
/>
|
||||
|
||||
<div className="flex h-[calc(100dvh)] flex-col w-full">
|
||||
{livePersona && (
|
||||
{liveAssistant && (
|
||||
<FunctionalHeader
|
||||
page="chat"
|
||||
setSharingModalVisible={
|
||||
@@ -1203,6 +1203,23 @@ export function ChatPage({
|
||||
currentChatSession={selectedChatSession}
|
||||
/>
|
||||
)}
|
||||
<div className="w-full flex">
|
||||
<div
|
||||
style={{ transition: "width 0.30s ease-out" }}
|
||||
className={`
|
||||
flex-none
|
||||
overflow-y-hidden
|
||||
bg-background-100
|
||||
transition-all
|
||||
bg-opacity-80
|
||||
duration-300
|
||||
ease-in-out
|
||||
h-full
|
||||
${toggledSidebar || showDocSidebar ? "w-[300px]" : "w-[0px]"}
|
||||
`}
|
||||
/>
|
||||
<ChatBanner />
|
||||
</div>
|
||||
{documentSidebarInitialWidth !== undefined ? (
|
||||
<Dropzone onDrop={handleImageUpload} noClick>
|
||||
{({ getRootProps }) => (
|
||||
@@ -1231,15 +1248,14 @@ export function ChatPage({
|
||||
ref={scrollableDivRef}
|
||||
>
|
||||
{/* ChatBanner is a custom banner that displays an admin-specified message at
the top of the chat page. Only used in the EE version of the app. */}
<ChatBanner />
the top of the chat page. Only used in the EE version of the app. */}
|
||||
|
||||
{messageHistory.length === 0 &&
|
||||
!isFetchingChatMessages &&
|
||||
!isStreaming && (
|
||||
<ChatIntro
|
||||
availableSources={finalAvailableSources}
|
||||
selectedPersona={livePersona}
|
||||
selectedPersona={liveAssistant}
|
||||
/>
|
||||
)}
|
||||
<div
|
||||
@@ -1319,7 +1335,7 @@ export function ChatPage({
|
||||
|
||||
const currentAlternativeAssistant =
|
||||
message.alternateAssistantID != null
|
||||
? availablePersonas.find(
|
||||
? availableAssistants.find(
|
||||
(persona) =>
|
||||
persona.id ==
|
||||
message.alternateAssistantID
|
||||
@@ -1342,7 +1358,7 @@ export function ChatPage({
|
||||
toggleDocumentSelectionAspects
|
||||
}
|
||||
docs={message.documents}
|
||||
currentPersona={livePersona}
|
||||
currentPersona={liveAssistant}
|
||||
alternativeAssistant={
|
||||
currentAlternativeAssistant
|
||||
}
|
||||
@@ -1352,7 +1368,7 @@ export function ChatPage({
|
||||
query={
|
||||
messageHistory[i]?.query || undefined
|
||||
}
|
||||
personaName={livePersona.name}
|
||||
personaName={liveAssistant.name}
|
||||
citedDocuments={getCitedDocumentsFromMessage(
|
||||
message
|
||||
)}
|
||||
@@ -1404,7 +1420,7 @@ export function ChatPage({
|
||||
messageIdToResend:
|
||||
previousMessage.messageId,
|
||||
queryOverride: newQuery,
|
||||
alternativeAssistant:
|
||||
alternativeAssistantOverride:
|
||||
currentAlternativeAssistant,
|
||||
});
|
||||
}
|
||||
@@ -1435,7 +1451,7 @@ export function ChatPage({
|
||||
messageIdToResend:
|
||||
previousMessage.messageId,
|
||||
forceSearch: true,
|
||||
alternativeAssistant:
|
||||
alternativeAssistantOverride:
|
||||
currentAlternativeAssistant,
|
||||
});
|
||||
} else {
|
||||
@@ -1460,9 +1476,9 @@ export function ChatPage({
|
||||
return (
|
||||
<div key={messageReactComponentKey}>
|
||||
<AIMessage
|
||||
currentPersona={livePersona}
|
||||
currentPersona={liveAssistant}
|
||||
messageId={message.messageId}
|
||||
personaName={livePersona.name}
|
||||
personaName={liveAssistant.name}
|
||||
content={
|
||||
<p className="text-red-700 text-sm my-auto">
|
||||
{message.message}
|
||||
@@ -1481,13 +1497,13 @@ export function ChatPage({
|
||||
key={`${messageHistory.length}-${chatSessionIdRef.current}`}
|
||||
>
|
||||
<AIMessage
|
||||
currentPersona={livePersona}
|
||||
currentPersona={liveAssistant}
|
||||
alternativeAssistant={
|
||||
alternativeGeneratingAssistant ??
|
||||
selectedAssistant
|
||||
alternativeAssistant
|
||||
}
|
||||
messageId={null}
|
||||
personaName={livePersona.name}
|
||||
personaName={liveAssistant.name}
|
||||
content={
|
||||
<div className="text-sm my-auto">
|
||||
<ThreeDots
|
||||
@@ -1513,7 +1529,7 @@ export function ChatPage({
|
||||
{currentPersona &&
|
||||
currentPersona.starter_messages &&
|
||||
currentPersona.starter_messages.length > 0 &&
|
||||
selectedPersona &&
|
||||
selectedAssistant &&
|
||||
messageHistory.length === 0 &&
|
||||
!isFetchingChatMessages && (
|
||||
<div
|
||||
@@ -1570,32 +1586,25 @@ export function ChatPage({
|
||||
<ChatInputBar
|
||||
showDocs={() => setDocumentSelection(true)}
|
||||
selectedDocuments={selectedDocuments}
|
||||
setSelectedAssistant={onPersonaChange}
|
||||
onSetSelectedAssistant={(
|
||||
alternativeAssistant: Persona | null
|
||||
) => {
|
||||
setSelectedAssistant(alternativeAssistant);
|
||||
}}
|
||||
alternativeAssistant={selectedAssistant}
|
||||
personas={filteredAssistants}
|
||||
// assistant stuff
|
||||
assistantOptions={filteredAssistants}
|
||||
selectedAssistant={liveAssistant}
|
||||
setSelectedAssistant={onAssistantChange}
|
||||
setAlternativeAssistant={setAlternativeAssistant}
|
||||
alternativeAssistant={alternativeAssistant}
|
||||
// end assistant stuff
|
||||
message={message}
|
||||
setMessage={setMessage}
|
||||
onSubmit={onSubmit}
|
||||
isStreaming={isStreaming}
|
||||
setIsCancelled={setIsCancelled}
|
||||
retrievalDisabled={
|
||||
!personaIncludesRetrieval(currentPersona)
|
||||
}
|
||||
filterManager={filterManager}
|
||||
llmOverrideManager={llmOverrideManager}
|
||||
selectedAssistant={livePersona}
|
||||
files={currentMessageFiles}
|
||||
setFiles={setCurrentMessageFiles}
|
||||
handleFileUpload={handleImageUpload}
|
||||
setConfigModalActiveTab={setConfigModalActiveTab}
|
||||
textAreaRef={textAreaRef}
|
||||
chatSessionId={chatSessionIdRef.current!}
|
||||
availableAssistants={availablePersonas}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -5,10 +5,10 @@ import { ChatPage } from "./ChatPage";
|
||||
import FunctionalWrapper from "./shared_chat_search/FunctionalWrapper";
|
||||
|
||||
export default function WrappedChat({
|
||||
defaultPersonaId,
|
||||
defaultAssistantId,
|
||||
initiallyToggled,
|
||||
}: {
|
||||
defaultPersonaId?: number;
|
||||
defaultAssistantId?: number;
|
||||
initiallyToggled: boolean;
|
||||
}) {
|
||||
return (
|
||||
@@ -17,7 +17,7 @@ export default function WrappedChat({
|
||||
content={(toggledSidebar, toggle) => (
|
||||
<ChatPage
|
||||
toggle={toggle}
|
||||
defaultSelectedPersonaId={defaultPersonaId}
|
||||
defaultSelectedAssistantId={defaultAssistantId}
|
||||
toggledSidebar={toggledSidebar}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -21,7 +21,6 @@ import { IconType } from "react-icons";
|
||||
import Popup from "../../../components/popup/Popup";
|
||||
import { LlmTab } from "../modal/configuration/LlmTab";
|
||||
import { AssistantsTab } from "../modal/configuration/AssistantsTab";
|
||||
import ChatInputAssistant from "./ChatInputAssistant";
|
||||
import { DanswerDocument } from "@/lib/search/interfaces";
|
||||
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
|
||||
import { Tooltip } from "@/components/tooltip/Tooltip";
|
||||
@@ -29,7 +28,6 @@ import { Hoverable } from "@/components/Hoverable";
|
||||
const MAX_INPUT_HEIGHT = 200;
|
||||
|
||||
export function ChatInputBar({
|
||||
personas,
|
||||
showDocs,
|
||||
selectedDocuments,
|
||||
message,
|
||||
@@ -37,34 +35,32 @@ export function ChatInputBar({
|
||||
onSubmit,
|
||||
isStreaming,
|
||||
setIsCancelled,
|
||||
retrievalDisabled,
|
||||
filterManager,
|
||||
llmOverrideManager,
|
||||
onSetSelectedAssistant,
|
||||
selectedAssistant,
|
||||
files,
|
||||
|
||||
// assistants
|
||||
selectedAssistant,
|
||||
assistantOptions,
|
||||
setSelectedAssistant,
|
||||
setAlternativeAssistant,
|
||||
|
||||
files,
|
||||
setFiles,
|
||||
handleFileUpload,
|
||||
setConfigModalActiveTab,
|
||||
textAreaRef,
|
||||
alternativeAssistant,
|
||||
chatSessionId,
|
||||
availableAssistants,
|
||||
}: {
|
||||
showDocs: () => void;
|
||||
selectedDocuments: DanswerDocument[];
|
||||
availableAssistants: Persona[];
|
||||
onSetSelectedAssistant: (alternativeAssistant: Persona | null) => void;
|
||||
assistantOptions: Persona[];
|
||||
setAlternativeAssistant: (alternativeAssistant: Persona | null) => void;
|
||||
setSelectedAssistant: (assistant: Persona) => void;
|
||||
personas: Persona[];
|
||||
message: string;
|
||||
setMessage: (message: string) => void;
|
||||
onSubmit: () => void;
|
||||
isStreaming: boolean;
|
||||
setIsCancelled: (value: boolean) => void;
|
||||
retrievalDisabled: boolean;
|
||||
filterManager: FilterManager;
|
||||
llmOverrideManager: LlmOverrideManager;
|
||||
selectedAssistant: Persona;
|
||||
@@ -72,7 +68,6 @@ export function ChatInputBar({
|
||||
files: FileDescriptor[];
|
||||
setFiles: (files: FileDescriptor[]) => void;
|
||||
handleFileUpload: (files: File[]) => void;
|
||||
setConfigModalActiveTab: (tab: string) => void;
|
||||
textAreaRef: React.RefObject<HTMLTextAreaElement>;
|
||||
chatSessionId?: number;
|
||||
}) {
|
||||
@@ -136,8 +131,10 @@ export function ChatInputBar({
|
||||
};
|
||||
|
||||
// Update selected persona
|
||||
const updateCurrentPersona = (persona: Persona) => {
|
||||
onSetSelectedAssistant(persona.id == selectedAssistant.id ? null : persona);
|
||||
const updatedTaggedAssistant = (assistant: Persona) => {
|
||||
setAlternativeAssistant(
|
||||
assistant.id == selectedAssistant.id ? null : assistant
|
||||
);
|
||||
hideSuggestions();
|
||||
setMessage("");
|
||||
};
|
||||
@@ -160,8 +157,8 @@ export function ChatInputBar({
|
||||
}
|
||||
};
|
||||
|
||||
const filteredPersonas = personas.filter((persona) =>
|
||||
persona.name.toLowerCase().startsWith(
|
||||
const assistantTagOptions = assistantOptions.filter((assistant) =>
|
||||
assistant.name.toLowerCase().startsWith(
|
||||
message
|
||||
.slice(message.lastIndexOf("@") + 1)
|
||||
.split(/\s/)[0]
|
||||
@@ -174,18 +171,18 @@ export function ChatInputBar({
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (
|
||||
showSuggestions &&
|
||||
filteredPersonas.length > 0 &&
|
||||
assistantTagOptions.length > 0 &&
|
||||
(e.key === "Tab" || e.key == "Enter")
|
||||
) {
|
||||
e.preventDefault();
|
||||
if (assistantIconIndex == filteredPersonas.length) {
|
||||
if (assistantIconIndex == assistantTagOptions.length) {
|
||||
window.open("/assistants/new", "_blank");
|
||||
hideSuggestions();
|
||||
setMessage("");
|
||||
} else {
|
||||
const option =
|
||||
filteredPersonas[assistantIconIndex >= 0 ? assistantIconIndex : 0];
|
||||
updateCurrentPersona(option);
|
||||
assistantTagOptions[assistantIconIndex >= 0 ? assistantIconIndex : 0];
|
||||
updatedTaggedAssistant(option);
|
||||
}
|
||||
}
|
||||
if (!showSuggestions) {
|
||||
@@ -195,7 +192,7 @@ export function ChatInputBar({
|
||||
if (e.key === "ArrowDown") {
|
||||
e.preventDefault();
|
||||
setAssistantIconIndex((assistantIconIndex) =>
|
||||
Math.min(assistantIconIndex + 1, filteredPersonas.length)
|
||||
Math.min(assistantIconIndex + 1, assistantTagOptions.length)
|
||||
);
|
||||
} else if (e.key === "ArrowUp") {
|
||||
e.preventDefault();
|
||||
@@ -219,35 +216,36 @@ export function ChatInputBar({
|
||||
mx-auto
|
||||
"
|
||||
>
|
||||
{showSuggestions && filteredPersonas.length > 0 && (
|
||||
{showSuggestions && assistantTagOptions.length > 0 && (
|
||||
<div
|
||||
ref={suggestionsRef}
|
||||
className="text-sm absolute inset-x-0 top-0 w-full transform -translate-y-full"
|
||||
>
|
||||
<div className="rounded-lg py-1.5 bg-background border border-border-medium shadow-lg mx-2 px-1.5 mt-2 rounded z-10">
|
||||
{filteredPersonas.map((currentPersona, index) => (
|
||||
{assistantTagOptions.map((currentAssistant, index) => (
|
||||
<button
|
||||
key={index}
|
||||
className={`px-2 ${
|
||||
assistantIconIndex == index && "bg-hover-lightish"
|
||||
} rounded rounded-lg content-start flex gap-x-1 py-2 w-full hover:bg-hover-lightish cursor-pointer`}
|
||||
onClick={() => {
|
||||
updateCurrentPersona(currentPersona);
|
||||
updatedTaggedAssistant(currentAssistant);
|
||||
}}
|
||||
>
|
||||
<p className="font-bold">{currentPersona.name}</p>
|
||||
<p className="font-bold">{currentAssistant.name}</p>
|
||||
<p className="line-clamp-1">
|
||||
{currentPersona.id == selectedAssistant.id &&
|
||||
{currentAssistant.id == selectedAssistant.id &&
|
||||
"(default) "}
|
||||
{currentPersona.description}
|
||||
{currentAssistant.description}
|
||||
</p>
|
||||
</button>
|
||||
))}
|
||||
<a
|
||||
key={filteredPersonas.length}
|
||||
key={assistantTagOptions.length}
|
||||
target="_blank"
|
||||
className={`${
|
||||
assistantIconIndex == filteredPersonas.length && "bg-hover"
|
||||
assistantIconIndex == assistantTagOptions.length &&
|
||||
"bg-hover"
|
||||
} rounded rounded-lg px-3 flex gap-x-1 py-2 w-full items-center hover:bg-hover-lightish cursor-pointer"`}
|
||||
href="/assistants/new"
|
||||
>
|
||||
@@ -301,7 +299,7 @@ export function ChatInputBar({

<Hoverable
icon={FiX}
onClick={() => onSetSelectedAssistant(null)}
onClick={() => setAlternativeAssistant(null)}
/>
</div>
</div>
@@ -409,7 +407,7 @@ export function ChatInputBar({
removePadding
content={(close) => (
<AssistantsTab
availableAssistants={availableAssistants}
availableAssistants={assistantOptions}
llmProviders={llmProviders}
selectedAssistant={selectedAssistant}
onSelect={(assistant) => {

@@ -505,7 +505,7 @@ export function removeMessage(

export function checkAnyAssistantHasSearch(
messageHistory: Message[],
availablePersonas: Persona[],
availableAssistants: Persona[],
livePersona: Persona
): boolean {
const response =
@@ -516,8 +516,8 @@ export function checkAnyAssistantHasSearch(
) {
return false;
}
const alternateAssistant = availablePersonas.find(
(persona) => persona.id === message.alternateAssistantID
const alternateAssistant = availableAssistants.find(
(assistant) => assistant.id === message.alternateAssistantID
);
return alternateAssistant
? personaIncludesRetrieval(alternateAssistant)

@@ -58,7 +58,6 @@ import { ValidSources } from "@/lib/types";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { useMouseTracking } from "./hooks";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { getTitleFromDocument } from "@/lib/sources";

const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@@ -445,7 +444,8 @@ export const AIMessage = ({
>
<Citation link={doc.link} index={ind + 1} />
<p className="shrink truncate ellipsis break-all ">
{getTitleFromDocument(doc)}
{doc.semantic_identifier ||
doc.document_id}
</p>
<div className="ml-auto flex-none">
{doc.is_internet ? (
@@ -798,7 +798,7 @@ export const HumanMessage = ({
!isEditing &&
(!files || files.length === 0)
) && "ml-auto"
} relative max-w-[70%] mb-auto rounded-3xl bg-user px-5 py-2.5`}
} relative max-w-[70%] mb-auto whitespace-break-spaces rounded-3xl bg-user px-5 py-2.5`}
>
{content}
</div>

@@ -62,11 +62,10 @@ export function AssistantsTab({
toolName = "Image Generation";
toolIcon = <FiImage className="mr-1 my-auto" />;
}

return (
<Bubble key={tool.id} isSelected={false}>
<div className="flex flex-row gap-1">
{toolIcon}
<div className="flex line-wrap break-all flex-row gap-1">
<div className="flex-none my-auto">{toolIcon}</div>
{toolName}
</div>
</Bubble>
@@ -1,180 +0,0 @@
"use client";

import React, { useEffect } from "react";
import { Modal } from "../../../../components/Modal";
import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
import { FiltersTab } from "./FiltersTab";
import { FiCpu, FiFilter, FiX } from "react-icons/fi";
import { IconType } from "react-icons";
import { FaBrain } from "react-icons/fa";
import { AssistantsTab } from "./AssistantsTab";
import { Persona } from "@/app/admin/assistants/interfaces";
import { LlmTab } from "./LlmTab";
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
import { AssistantsIcon, IconProps } from "@/components/icons/icons";

const TabButton = ({
label,
icon: Icon,
isActive,
onClick,
}: {
label: string;
icon: IconType;
isActive: boolean;
onClick: () => void;
}) => (
<button
onClick={onClick}
className={`
pb-4
pt-6
px-2
text-emphasis
font-bold
${isActive ? "border-b-2 border-accent" : ""}
hover:bg-hover-light
hover:text-strong
transition
duration-200
ease-in-out
flex
`}
>
<Icon className="inline-block mr-2 my-auto" size="16" />

<p className="my-auto">{label}</p>
</button>
);

export function ConfigurationModal({
activeTab,
setActiveTab,
onClose,
availableAssistants,
selectedAssistant,
setSelectedAssistant,
filterManager,
llmProviders,
llmOverrideManager,
chatSessionId,
}: {
activeTab: string | null;
setActiveTab: (tab: string | null) => void;
onClose: () => void;
availableAssistants: Persona[];
selectedAssistant: Persona;
setSelectedAssistant: (assistant: Persona) => void;
filterManager: FilterManager;
llmProviders: LLMProviderDescriptor[];
llmOverrideManager: LlmOverrideManager;
chatSessionId?: number;
}) {
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === "Escape") {
onClose();
}
};

document.addEventListener("keydown", handleKeyDown);
return () => {
document.removeEventListener("keydown", handleKeyDown);
};
}, [onClose]);

if (!activeTab) return null;

return (
<Modal
onOutsideClick={onClose}
noPadding
className="
w-4/6
h-4/6
flex
flex-col
"
>
<div className="rounded flex flex-col overflow-hidden">
<div className="mb-4">
<div className="flex border-b border-border bg-background-emphasis">
<div className="flex px-6 gap-x-2">
<TabButton
label="Assistants"
icon={FaBrain}
isActive={activeTab === "assistants"}
onClick={() => setActiveTab("assistants")}
/>
<TabButton
label="Models"
icon={FiCpu}
isActive={activeTab === "llms"}
onClick={() => setActiveTab("llms")}
/>
<TabButton
label="Filters"
icon={FiFilter}
isActive={activeTab === "filters"}
onClick={() => setActiveTab("filters")}
/>
</div>
<button
className="
ml-auto
px-1
py-1
text-xs
font-medium
rounded
hover:bg-hover
focus:outline-none
focus:ring-2
focus:ring-offset-2
focus:ring-subtle
flex
items-center
h-fit
my-auto
mr-5
"
onClick={onClose}
>
<FiX size={24} />
</button>
</div>
</div>

<div className="flex flex-col overflow-y-auto">
<div className="px-8 pt-4">
{activeTab === "filters" && (
<FiltersTab filterManager={filterManager} />
)}

{activeTab === "llms" && (
<LlmTab
chatSessionId={chatSessionId}
llmOverrideManager={llmOverrideManager}
currentAssistant={selectedAssistant}
/>
)}

{activeTab === "assistants" && (
<div>
<AssistantsTab
availableAssistants={availableAssistants}
llmProviders={llmProviders}
selectedAssistant={selectedAssistant}
onSelect={(assistant) => {
setSelectedAssistant(assistant);
onClose();
}}
/>
</div>
)}
</div>
</div>
</div>
</Modal>
);
}
@@ -35,7 +35,7 @@ export default async function Page({
folders,
toggleSidebar,
openedFolders,
defaultPersonaId,
defaultAssistantId,
finalDocumentSidebarInitialWidth,
shouldShowWelcomeModal,
shouldDisplaySourcesIncompleteModal,
@@ -58,7 +58,7 @@ export default async function Page({
chatSessions,
availableSources,
availableDocumentSets: documentSets,
availablePersonas: assistants,
availableAssistants: assistants,
availableTags: tags,
llmProviders,
folders,
@@ -66,7 +66,7 @@ export default async function Page({
}}
>
<WrappedChat
defaultPersonaId={defaultPersonaId}
defaultAssistantId={defaultAssistantId}
initiallyToggled={toggleSidebar}
/>
</ChatProvider>
@@ -1,111 +0,0 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import { BasicSelectable } from "@/components/BasicClickable";
import { AssistantsIcon } from "@/components/icons/icons";
import { User } from "@/lib/types";
import { Text } from "@tremor/react";
import Link from "next/link";
import { FaRobot } from "react-icons/fa";
import { FiEdit2 } from "react-icons/fi";

function AssistantDisplay({
persona,
onSelect,
user,
}: {
persona: Persona;
onSelect: (persona: Persona) => void;
user: User | null;
}) {
const isEditable =
(!user || user.id === persona.owner?.id) &&
!persona.default_persona &&
(!persona.is_public || !user || user.role === "admin");

return (
<div className="flex">
<div className="w-full" onClick={() => onSelect(persona)}>
<BasicSelectable selected={false} fullWidth>
<div className="flex">
<div className="truncate w-48 3xl:w-56 flex">
<AssistantsIcon className="mr-2 my-auto" size={16} />{" "}
{persona.name}
</div>
</div>
</BasicSelectable>
</div>
{isEditable && (
<div className="pl-2 my-auto">
<Link href={`/assistants/edit/${persona.id}`}>
<FiEdit2
className="my-auto ml-auto hover:bg-hover p-0.5"
size={20}
/>
</Link>
</div>
)}
</div>
);
}

export function AssistantsTab({
personas,
onPersonaChange,
user,
}: {
personas: Persona[];
onPersonaChange: (persona: Persona | null) => void;
user: User | null;
}) {
const globalAssistants = personas.filter((persona) => persona.is_public);
const personalAssistants = personas.filter(
(persona) =>
(!user || persona.users.some((u) => u.id === user.id)) &&
!persona.is_public
);

return (
<div className="mt-4 pb-1 overflow-y-auto h-full flex flex-col gap-y-1">
<Text className="mx-3 text-xs mb-4">
Select an Assistant below to begin a new chat with them!
</Text>

<div className="mx-3">
{globalAssistants.length > 0 && (
<>
<div className="text-xs text-subtle flex pb-0.5 ml-1 mb-1.5 font-bold">
Global
</div>
{globalAssistants.map((persona) => {
return (
<AssistantDisplay
key={persona.id}
persona={persona}
onSelect={onPersonaChange}
user={user}
/>
);
})}
</>
)}

{personalAssistants.length > 0 && (
<>
<div className="text-xs text-subtle flex pb-0.5 ml-1 mb-1.5 mt-5 font-bold">
Personal
</div>
{personalAssistants.map((persona) => {
return (
<AssistantDisplay
key={persona.id}
persona={persona}
onSelect={onPersonaChange}
user={user}
/>
);
})}
</>
)}
</div>
</div>
);
}
@@ -99,22 +99,28 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
h-screen
transition-transform`}
>
<div className="ml-4 mr-3 flex flex gap-x-1 items-center mt-2 my-auto text-text-700 text-xl">
<div className="mr-1 my-auto h-6 w-6">
<div className="max-w-full ml-3 mr-3 mt-2 flex flex gap-x-1 items-center my-auto text-text-700 text-xl">
<div className="mr-1 mb-auto h-6 w-6">
<Logo height={24} width={24} />
</div>

<div className="invisible">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
<div>
<HeaderTitle>
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Danswer</p>
)}
</div>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
</div>

{toggleSidebar && (
<Tooltip delayDuration={1000} content={`${commandSymbol}E show`}>
<button className="ml-auto" onClick={toggleSidebar}>
<button className="mb-auto ml-auto" onClick={toggleSidebar}>
{!toggled ? <RightToLineIcon /> : <LefToLineIcon />}
</button>
</Tooltip>
@@ -3,6 +3,7 @@
import { HeaderTitle } from "@/components/header/Header";
import { Logo } from "@/components/Logo";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
import { useContext } from "react";

export default function FixedLogo() {
@@ -11,17 +12,24 @@ export default function FixedLogo() {
const enterpriseSettings = combinedSettings?.enterpriseSettings;

return (
<div className="fixed flex z-40 left-4 top-2">
{" "}
<a href="/chat" className="ml-7 text-text-700 text-xl">
<div>
<div className="absolute flex z-40 left-2.5 top-2">
<div className="max-w-[200px] flex gap-x-1 my-auto">
<div className="flex-none invisible mb-auto">
<Logo />
</div>
<div className="">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
<div>
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Danswer</p>
)}
</div>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
</div>
</a>
</div>
</div>
);
}
@@ -10,12 +10,11 @@ const ToggleSwitch = () => {
const commandSymbol = KeyboardSymbol();
const pathname = usePathname();
const router = useRouter();

const [activeTab, setActiveTab] = useState(() => {
if (typeof window !== "undefined") {
return localStorage.getItem("activeTab") || "chat";
}
return "chat";
return pathname == "/search" ? "search" : "chat";
});

const [isInitialLoad, setIsInitialLoad] = useState(true);

useEffect(() => {
@@ -168,7 +168,8 @@ export default async function Home() {
(ccPair) => ccPair.has_successful_run && ccPair.docs_indexed > 0
) &&
!shouldDisplayNoSourcesModal &&
!shouldShowWelcomeModal;
!shouldShowWelcomeModal &&
(!user || user.role == "admin");

const sidebarToggled = cookies().get(SIDEBAR_TOGGLED_COOKIE_NAME);
const agenticSearchToggle = cookies().get(AGENTIC_SEARCH_TYPE_COOKIE_NAME);
@@ -178,7 +179,7 @@ export default async function Home() {
: false;

const agenticSearchEnabled = agenticSearchToggle
? agenticSearchToggle.value.toLocaleLowerCase() == "true" || true
? agenticSearchToggle.value.toLocaleLowerCase() == "true" || false
: false;

return (
web/src/components/AdvancedOptionsToggle.tsx (new file, 26 lines)
@@ -0,0 +1,26 @@
import React from "react";
import { Button } from "@tremor/react";
import { FiChevronDown, FiChevronRight } from "react-icons/fi";

interface AdvancedOptionsToggleProps {
showAdvancedOptions: boolean;
setShowAdvancedOptions: (show: boolean) => void;
}

export function AdvancedOptionsToggle({
showAdvancedOptions,
setShowAdvancedOptions,
}: AdvancedOptionsToggleProps) {
return (
<Button
type="button"
variant="light"
size="xs"
icon={showAdvancedOptions ? FiChevronDown : FiChevronRight}
onClick={() => setShowAdvancedOptions(!showAdvancedOptions)}
className="mb-4 text-xs text-text-500 hover:text-text-400"
>
Advanced Options
</Button>
);
}
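Note: the new AdvancedOptionsToggle above is a small controlled toggle button; a minimal usage sketch follows. The parent component and its useState wiring are assumptions for illustration only and are not part of this diff.

import { useState } from "react";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";

function ProviderFormExample() {
  // The parent owns the open/closed state; the toggle only flips it
  const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);

  return (
    <div>
      <AdvancedOptionsToggle
        showAdvancedOptions={showAdvancedOptions}
        setShowAdvancedOptions={setShowAdvancedOptions}
      />
      {showAdvancedOptions && <div>{/* advanced fields render here */}</div>}
    </div>
  );
}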
@@ -31,7 +31,10 @@ export function Logo({
}

return (
<div style={{ height, width }} className={`relative ${className}`}>
<div
style={{ height, width }}
className={`flex-none relative ${className}`}
>
{/* TODO: figure out how to use Next Image here */}
<img
src="/api/enterprise-settings/logo"

@@ -91,7 +91,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
return (
<div className="h-screen overflow-y-hidden">
<div className="flex h-full">
<div className="w-64 z-20 bg-background-100 pt-4 pb-8 h-full border-r border-border miniscroll overflow-auto">
<div className="w-64 z-20 bg-background-100 pt-3 pb-8 h-full border-r border-border miniscroll overflow-auto">
<AdminSidebar
collections={[
{
@@ -28,9 +28,9 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {

return (
<aside className="pl-0">
<nav className="space-y-2 pl-4">
<div className="pb-12 flex">
<div className="fixed left-0 top-0 py-2 pl-4 bg-background-100 w-[200px]">
<nav className="space-y-2 pl-2">
<div className="mb-4 flex">
<div className="bg-background-100">
<Link
className="flex flex-col"
href={
@@ -39,8 +39,8 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
: "/search"
}
>
<div className="flex gap-x-1 my-auto">
<div className="my-auto">
<div className="max-w-[200px] flex gap-x-1 my-auto">
<div className="flex-none mb-auto">
<Logo />
</div>
<div className="my-auto">
@@ -50,7 +50,7 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle -mt-1.5">
<p className="text-xs text-subtle">
Powered by Danswer
</p>
)}
@@ -63,14 +63,16 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
</Link>
</div>
</div>
<Link href={"/chat"}>
<button className="text-sm block w-48 py-2.5 flex px-2 text-left bg-background-200 hover:bg-background-200/80 cursor-pointer rounded">
<BackIcon size={20} className="text-neutral" />
<p className="ml-1">Back to Danswer</p>
</button>
</Link>
<div className="px-3">
<Link href={"/chat"}>
<button className="text-sm block w-48 py-2.5 flex px-2 text-left bg-background-200 hover:bg-background-200/80 cursor-pointer rounded">
<BackIcon size={20} className="text-neutral" />
<p className="ml-1">Back to Danswer</p>
</button>
</Link>
</div>
{collections.map((collection, collectionInd) => (
<div key={collectionInd}>
<div className="px-3" key={collectionInd}>
<h2 className="text-xs text-strong font-bold pb-2">
<div>{collection.name}</div>
</h2>
@@ -59,18 +59,23 @@ export default function FunctionalHeader({
<div className="pb-6 left-0 sticky top-0 z-10 w-full relative flex">
<div className="mt-2 mx-4 text-text-700 flex w-full">
<div className="absolute z-[100] my-auto flex items-center text-xl font-bold">
<FiSidebar size={20} />
<div className="ml-2 text-text-700 text-xl">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
<div className="pt-[2px] mb-auto">
<FiSidebar size={20} />
</div>
<div className="break-words inline-block w-fit ml-2 text-text-700 text-xl">
<div className="max-w-[200px]">
{enterpriseSettings && enterpriseSettings.application_name ? (
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
) : (
<HeaderTitle>Danswer</HeaderTitle>
)}
</div>
</div>

{page == "chat" && (
<Tooltip delayDuration={1000} content={`${commandSymbol}U`}>
<Link
className="mb-auto pt-[2px]"
href={
`/${page}` +
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA &&
@@ -79,10 +84,9 @@ export default function FunctionalHeader({
: "")
}
>
<NewChatIcon
size={20}
className="ml-2 my-auto cursor-pointer text-text-700 hover:text-text-600 transition-colors duration-300"
/>
<div className="cursor-pointer ml-2 flex-none text-text-700 hover:text-text-600 transition-colors duration-300">
<NewChatIcon size={20} className="" />
</div>
</Link>
</Tooltip>
)}
@@ -18,7 +18,7 @@ interface ChatContextProps {
chatSessions: ChatSession[];
availableSources: ValidSources[];
availableDocumentSets: DocumentSet[];
availablePersonas: Persona[];
availableAssistants: Persona[];
availableTags: Tag[];
llmProviders: LLMProviderDescriptor[];
folders: Folder[];
@@ -11,7 +11,11 @@ import { Logo } from "../Logo";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";

export function HeaderTitle({ children }: { children: JSX.Element | string }) {
return <h1 className="flex text-2xl text-strong font-bold">{children}</h1>;
return (
<h1 className="flex text-2xl text-strong leading-none font-bold">
{children}
</h1>
);
}

interface HeaderProps {
@@ -36,8 +40,8 @@ export function Header({ user, page }: HeaderProps) {
settings && settings.default_page === "chat" ? "/chat" : "/search"
}
>
<div className="flex my-auto">
<div className="mr-1 my-auto">
<div className="max-w-[200px] bg-black flex my-auto">
<div className="mr-1 mb-auto">
<Logo />
</div>
<div className="my-auto">
@@ -47,9 +51,7 @@ export function Header({ user, page }: HeaderProps) {
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle -mt-1.5">
Powered by Danswer
</p>
<p className="text-xs text-subtle">Powered by Danswer</p>
)}
</div>
) : (
@@ -19,7 +19,7 @@ export function Citation({
return (
<CustomTooltip
citation
content={<p className="inline-block p-0 m-0 truncate">{link}</p>}
content={<div className="inline-block p-0 m-0 truncate">{link}</div>}
>
<a
onClick={() => (link ? window.open(link, "_blank") : undefined)}
@@ -30,7 +30,6 @@ import { fetchAssistantsSS } from "../assistants/fetchAssistantsSS";
interface FetchChatDataResult {
user: User | null;
chatSessions: ChatSession[];

ccPairs: CCPairBasicInfo[];
availableSources: ValidSources[];
documentSets: DocumentSet[];
@@ -39,7 +38,7 @@ interface FetchChatDataResult {
llmProviders: LLMProviderDescriptor[];
folders: Folder[];
openedFolders: Record<string, boolean>;
defaultPersonaId?: number;
defaultAssistantId?: number;
toggleSidebar: boolean;
finalDocumentSidebarInitialWidth?: number;
shouldShowWelcomeModal: boolean;
@@ -150,9 +149,9 @@ export async function fetchChatData(searchParams: {
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
}

const defaultPersonaIdRaw = searchParams["assistantId"];
const defaultPersonaId = defaultPersonaIdRaw
? parseInt(defaultPersonaIdRaw)
const defaultAssistantIdRaw = searchParams["assistantId"];
const defaultAssistantId = defaultAssistantIdRaw
? parseInt(defaultAssistantIdRaw)
: undefined;

const documentSidebarCookieInitialWidth = cookies().get(
@@ -178,7 +177,8 @@ export async function fetchChatData(searchParams: {
!shouldShowWelcomeModal &&
!ccPairs.some(
(ccPair) => ccPair.has_successful_run && ccPair.docs_indexed > 0
);
) &&
(!user || user.role == "admin");

// if no connectors are setup, only show personas that are pure
// passthrough and don't do any retrieval
@@ -209,7 +209,7 @@ export async function fetchChatData(searchParams: {
llmProviders,
folders,
openedFolders,
defaultPersonaId,
defaultAssistantId,
finalDocumentSidebarInitialWidth,
toggleSidebar,
shouldShowWelcomeModal,
@@ -301,9 +301,3 @@ function stripTrailingSlash(str: string) {
}
return str;
}

export const getTitleFromDocument = (document: DanswerDocument) => {
return stripTrailingSlash(document.document_id).split("/")[
stripTrailingSlash(document.document_id).split("/").length - 1
];
};