Compare commits

...

21 Commits

Author SHA1 Message Date
rkuo-danswer
e662e3b57d clarify ssl cert reqs (#2494)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-18 05:35:57 +00:00
pablodanswer
2073820e33 Update default assistants to all visible (#2490)
* update default assistants to all visible

* update with catch-all

* minor update

* update
2024-09-18 02:08:11 +00:00
Chris Weaver
5f25b243c5 Add back llm_chunks_indices (#2491) 2024-09-18 01:21:31 +00:00
pablodanswer
a9427f190a Extend time range (contributor submission) (#2484)
* added new options for time range; removed duplicated code

* refactor + remove unused code

---------

Co-authored-by: Zoltan Szabo <zoltan.szabo@eaudeweb.ro>
2024-09-17 22:36:25 +00:00
pablodanswer
18fbe9d7e8 Warn users of gpu-sensitive operation (#2488)
* warn users of gpu-sensitive operation

* update copy
2024-09-17 21:59:43 +00:00
Chris Weaver
75c9b1cafe Fix concatenate string with toolcallkickoff issue (#2487) 2024-09-17 21:25:06 +00:00
rkuo-danswer
632a8f700b Feature/celery backend db number (#2475)
* use separate database number for celery result backend

* add comments

* add env var for celery's result_expires

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-17 21:06:36 +00:00
pablodanswer
cd58c96014 Memoize AI message component (#2483)
* memoize AI message component

* rename memoized file

* remove "zz"

* update name

* memoize for coverage

* add display name
2024-09-17 18:47:23 +00:00
pablodanswer
c5032d25c9 Minor clarity update for connectors (#2480) 2024-09-17 10:25:39 -07:00
pablodanswer
72acde6fd4 Handle tool errors in display properly (can show valueError to user) (#2481)
* handle tool errors in display properly (can show valueerrors to user)

* update for clarity
2024-09-17 17:08:46 +00:00
rkuo-danswer
5596a68d08 harden migration (#2476)
* harden migration

* remove duplicate line
2024-09-17 16:44:53 +00:00
Weves
5b18409c89 Change user-message to user-prompt 2024-09-16 21:53:27 -07:00
Chris Weaver
84272af5ac Add back scrolling to ExceptionTraceModal (#2473) 2024-09-17 02:25:53 +00:00
pablodanswer
6bef70c8b7 ensure disabled gets propagated 2024-09-16 19:27:31 -07:00
pablodanswer
7f7559e3d2 Allow users to share assistants (#2434)
* enable assistant sharing

* functional

* remove logs

* revert ports

* remove accidental update

* minor updates to copy

* update formatting

* update for merge queue
2024-09-17 01:35:29 +00:00
Chris Weaver
7ba829a585 Add top_documents to APIs (#2469)
* Add top_documents

* Fix test

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-09-16 23:48:33 +00:00
trial-danswer
8b2ecb4eab EE movement followup for Standard Answers (#2467)
* Move StandardAnswer to EE section of danswer/db/models

* Move StandardAnswer DB layer to EE

* Add EERequiredError for distinct error handling here

* Handle EE fallback for slack bot config

* Migrate all standard answer models to ee

* Flagging categories for removal

* Add missing versioned impl for update_slack_bot_config

---------

Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
2024-09-16 22:05:53 +00:00
pablodanswer
2dd3870504 Add ability to specify persona in API request (#2302)
* persona

* all prepared excluding configuration

* more sensical model structure

* update tstream

* type updates

* rm

* quick and simple updates

* minor updates

* te

* ensure typing + naming

* remove old todo + rebase update

* remove unnecessary check
2024-09-16 21:31:01 +00:00
pablodanswer
df464fc54b Allow for CORS Origin Setting (#2449)
* allow setting of CORS origin

* simplify

* add environment variable + rename

* slightly more efficient

* simplify so mypy doesn't complain

* temp

* go back to my preferred formatting
2024-09-16 18:54:36 +00:00
pablodanswer
96b98fbc4a Make it impossible to switch to non-image (#2440)
* make it impossible to switch to non-image

* revert ports

* proper provider support

* remove unused imports

* minor rename

* simplify interface

* remove logs
2024-09-16 18:35:40 +00:00
trial-danswer
66cf67d04d hotfix: sqlalchemy default -> server_default (#2442)
Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
2024-09-16 17:49:01 +00:00
76 changed files with 1540 additions and 1316 deletions

View File

@@ -0,0 +1,64 @@
"""server default chosen assistants
Revision ID: 35e6853a51d5
Revises: c99d76fcd298
Create Date: 2024-09-13 13:20:32.885317
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "35e6853a51d5"
down_revision = "c99d76fcd298"
branch_labels = None
depends_on = None
DEFAULT_ASSISTANTS = [-2, -1, 0]
def upgrade() -> None:
# Step 1: Update any NULL values to the default value
# This upgrades existing users without ordered assistant
# to have default assistants set to visible assistants which are
# accessible by them.
# The WHERE clause matches several historical "null" encodings: SQL NULL,
# the jsonb null value, and the jsonb string "null".
# NOTE(review): if a user has access to zero visible personas, jsonb_agg
# yields SQL NULL and chosen_assistants remains NULL, which would make the
# NOT NULL alter in Step 2 fail -- confirm every deployment seeds at least
# one visible public persona before this migration runs.
op.execute(
"""
UPDATE "user" u
SET chosen_assistants = (
SELECT jsonb_agg(
p.id ORDER BY
COALESCE(p.display_priority, 2147483647) ASC,
p.id ASC
)
FROM persona p
LEFT JOIN persona__user pu ON p.id = pu.persona_id AND pu.user_id = u.id
WHERE p.is_visible = true
AND (p.is_public = true OR pu.user_id IS NOT NULL)
)
WHERE chosen_assistants IS NULL
OR chosen_assistants = 'null'
OR jsonb_typeof(chosen_assistants) = 'null'
OR (jsonb_typeof(chosen_assistants) = 'string' AND chosen_assistants = '"null"')
"""
)
# Step 2: Alter the column to make it non-nullable
# New rows get DEFAULT_ASSISTANTS rendered as a jsonb array literal,
# i.e. '[-2, -1, 0]'::jsonb.
op.alter_column(
"user",
"chosen_assistants",
type_=postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default=sa.text(f"'{DEFAULT_ASSISTANTS}'::jsonb"),
)
def downgrade() -> None:
# Revert chosen_assistants to its pre-migration shape: nullable, with no
# server default. Existing row values are left untouched.
op.alter_column(
"user",
"chosen_assistants",
type_=postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
server_default=None,
)

View File

@@ -0,0 +1,31 @@
"""add nullable to persona id in Chat Session
Revision ID: c99d76fcd298
Revises: 5c7fdadae813
Create Date: 2024-07-09 19:27:01.579697
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c99d76fcd298"
down_revision = "5c7fdadae813"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Allow chat sessions to exist without a persisted persona.
# NOTE(review): presumably to support temporary/API-supplied persona
# configs (matches the ChatSession.persona_id model change in this same
# changeset) -- confirm against the calling code.
op.alter_column(
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
)
def downgrade() -> None:
# Restore NOT NULL on persona_id. This will fail if any chat_session rows
# now hold a NULL persona_id; those must be cleaned up first.
op.alter_column(
"chat_session",
"persona_id",
existing_type=sa.INTEGER(),
nullable=False,
)

View File

@@ -19,7 +19,9 @@ def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"standard_answer",
sa.Column("match_regex", sa.Boolean(), nullable=False, default=False),
sa.Column(
"match_regex", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
# ### end Alembic commands ###

View File

@@ -134,7 +134,7 @@ class RedisDocumentSet(RedisObjectHelper):
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
@@ -189,7 +189,7 @@ class RedisUserGroup(RedisObjectHelper):
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
@@ -256,7 +256,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"

View File

@@ -1,5 +1,7 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
@@ -27,7 +29,7 @@ if REDIS_SSL:
# example celery_broker_url: "redis://:password@localhost:6379/15"
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
# however, prefetching is bad when tasks are lengthy as those tasks
@@ -42,3 +44,33 @@ broker_transport_options = {
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True
# It's possible we don't even need celery's result backend, in which case all of the optimization below
# might be irrelevant
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
# Option 0: Defaults (json serializer, no compression)
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
# Option 1: Reduces generator task result sizes by roughly 20%
# task_compression = "bzip2"
# task_serializer = "pickle"
# result_compression = "bzip2"
# result_serializer = "pickle"
# accept_content=["pickle"]
# Option 2: this significantly reduces the size of the result for generator tasks since the list of children
# can be large. small tasks change very little
# def pickle_bz2_encoder(data):
# return bz2.compress(pickle.dumps(data))
# def pickle_bz2_decoder(data):
# return pickle.loads(bz2.decompress(data))
# from kombu import serialization # To register custom serialization with Celery/Kombu
# serialization.register('pickle-bzip2', pickle_bz2_encoder, pickle_bz2_decoder, 'application/x-pickle-bz2', 'binary')
# task_serializer = "pickle-bzip2"
# result_serializer = "pickle-bzip2"
# accept_content=["pickle", "pickle-bzip2"]

View File

@@ -675,9 +675,11 @@ def stream_chat_message_objects(
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
dedupe_docs=(
retrieval_options.dedupe_docs
if retrieval_options
else False
),
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
@@ -743,10 +745,18 @@ def stream_chat_message_objects(
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
error_msg = str(e)
logger.exception(f"Failed to process chat message: {error_msg}")
except ValueError as e:
logger.exception("Failed to process chat message.")
error_msg = str(e)
yield StreamingError(error=error_msg)
db_session.rollback()
return
except Exception as e:
logger.exception("Failed to process chat message.")
error_msg = str(e)
stack_trace = traceback.format_exc()
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
@@ -786,16 +796,18 @@ def stream_chat_message_objects(
if message_specific_citations
else None,
error=None,
tool_calls=[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else [],
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else []
),
)
logger.debug("Committing messages")

View File

@@ -159,11 +159,18 @@ REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
# Used by celery as broker and backend
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15))
REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
os.environ.get("REDIS_DB_NUMBER_CELERY_RESULT_BACKEND", 14)
)
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "CERT_NONE")
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# should be one of "required", "optional", or "none"
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
#####
# Connector Configs
#####

View File

@@ -5,6 +5,7 @@ from typing import cast
from typing import Optional
from typing import TypeVar
from fastapi import HTTPException
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
@@ -153,15 +154,23 @@ def handle_regular_answer(
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
if new_message_request.persona_config:
raise HTTPException(
status_code=403,
detail="Slack bot does not support persona config",
)
elif new_message_request.persona_id:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context

View File

@@ -226,7 +226,7 @@ def create_chat_session(
db_session: Session,
description: str,
user_id: UUID | None,
persona_id: int,
persona_id: int | None, # Can be none if temporary persona is used
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,

View File

@@ -4,9 +4,11 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import DocumentSet
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import SearchSettings
from danswer.db.models import Tool as ToolModel
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
@@ -103,6 +105,20 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_doc_sets(
db_session: Session, doc_ids: list[int]
) -> list[DocumentSet]:
return list(
db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all()
)
def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
return list(
db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all()
)
def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,

View File

@@ -122,7 +122,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
chosen_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=True
postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
@@ -866,7 +866,9 @@ class ChatSession(Base):
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
description: Mapped[str] = mapped_column(Text)
# One-shot direct answering, currently the two types of chats are not mixed
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
@@ -900,7 +902,6 @@ class ChatSession(Base):
prompt_override: Mapped[PromptOverride | None] = mapped_column(
PydanticType(PromptOverride), nullable=True
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -909,7 +910,6 @@ class ChatSession(Base):
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
folder: Mapped["ChatFolder"] = relationship(
"ChatFolder", back_populates="chat_sessions"
@@ -1347,55 +1347,6 @@ class ChannelConfig(TypedDict):
follow_up_tags: NotRequired[list[str]]
class StandardAnswerCategory(Base):
__tablename__ = "standard_answer_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="categories",
)
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
"SlackBotConfig",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="standard_answer_categories",
)
class StandardAnswer(Base):
__tablename__ = "standard_answer"
id: Mapped[int] = mapped_column(primary_key=True)
keyword: Mapped[str] = mapped_column(String)
answer: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
match_regex: Mapped[bool] = mapped_column(Boolean)
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
__table_args__ = (
Index(
"unique_keyword_active",
keyword,
active,
unique=True,
postgresql_where=(active == True), # noqa: E712
),
)
categories: Mapped[list[StandardAnswerCategory]] = relationship(
"StandardAnswerCategory",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="standard_answers",
)
chat_messages: Mapped[list[ChatMessage]] = relationship(
"ChatMessage",
secondary=ChatMessage__StandardAnswer.__table__,
back_populates="standard_answers",
)
class SlackBotResponseType(str, PyEnum):
QUOTES = "quotes"
CITATIONS = "citations"
@@ -1421,7 +1372,7 @@ class SlackBotConfig(Base):
)
persona: Mapped[Persona | None] = relationship("Persona")
standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship(
"StandardAnswerCategory",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="slack_bot_configs",
@@ -1651,6 +1602,55 @@ class TokenRateLimit__UserGroup(Base):
)
class StandardAnswerCategory(Base):
__tablename__ = "standard_answer_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="categories",
)
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
"SlackBotConfig",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="standard_answer_categories",
)
class StandardAnswer(Base):
__tablename__ = "standard_answer"
id: Mapped[int] = mapped_column(primary_key=True)
keyword: Mapped[str] = mapped_column(String)
answer: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
match_regex: Mapped[bool] = mapped_column(Boolean)
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
__table_args__ = (
Index(
"unique_keyword_active",
keyword,
active,
unique=True,
postgresql_where=(active == True), # noqa: E712
),
)
categories: Mapped[list[StandardAnswerCategory]] = relationship(
"StandardAnswerCategory",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="standard_answers",
)
chat_messages: Mapped[list[ChatMessage]] = relationship(
"ChatMessage",
secondary=ChatMessage__StandardAnswer.__table__,
back_populates="standard_answers",
)
"""Tables related to Permission Sync"""

View File

@@ -210,6 +210,22 @@ def update_persona_shared_users(
)
def update_persona_public_status(
    persona_id: int,
    is_public: bool,
    db_session: Session,
    user: User | None,
) -> None:
    """Set the is_public flag on a persona, enforcing owner/admin rights.

    Raises ValueError when a non-admin user tries to modify a persona they
    do not own. A None user (no auth) is treated as privileged.
    """
    target_persona = fetch_persona_by_id(
        db_session=db_session, persona_id=persona_id, user=user, get_editable=True
    )
    # Admins (or unauthenticated/internal callers) may edit anything;
    # everyone else may only edit personas they own.
    is_privileged = user is None or user.role == UserRole.ADMIN
    if not (is_privileged or target_persona.user_id == user.id):
        raise ValueError("You don't have permission to modify this persona")
    target_persona.is_public = is_public
    db_session.commit()
def get_prompts(
user_id: UUID | None,
db_session: Session,
@@ -551,6 +567,7 @@ def update_persona_visibility(
persona = fetch_persona_by_id(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
persona.is_visible = is_visible
db_session.commit()
@@ -563,13 +580,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
)
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()
return prompts
return list(prompts)
def get_prompt_by_id(

View File

@@ -1,4 +1,5 @@
from collections.abc import Sequence
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -14,8 +15,11 @@ from danswer.db.models import User
from danswer.db.persona import get_default_prompt
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.errors import EERequiredError
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
def _build_persona_name(channel_names: list[str]) -> str:
@@ -70,6 +74,10 @@ def create_slack_bot_persona(
return persona
def _no_ee_standard_answer_categories(*args: Any, **kwargs: Any) -> list:
return []
def insert_slack_bot_config(
persona_id: int | None,
channel_config: ChannelConfig,
@@ -78,14 +86,29 @@ def insert_slack_bot_config(
enable_auto_filters: bool,
db_session: Session,
) -> SlackBotConfig:
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
raise ValueError(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
versioned_fetch_standard_answer_categories_by_ids = (
fetch_versioned_implementation_with_fallback(
"danswer.db.standard_answer",
"fetch_standard_answer_categories_by_ids",
_no_ee_standard_answer_categories,
)
)
existing_standard_answer_categories = (
versioned_fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
if len(existing_standard_answer_categories) == 0:
raise EERequiredError(
"Standard answers are a paid Enterprise Edition feature - enable EE or remove standard answer categories"
)
else:
raise ValueError(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
)
slack_bot_config = SlackBotConfig(
persona_id=persona_id,
@@ -117,9 +140,18 @@ def update_slack_bot_config(
f"Unable to find slack bot config with ID {slack_bot_config_id}"
)
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
versioned_fetch_standard_answer_categories_by_ids = (
fetch_versioned_implementation_with_fallback(
"danswer.db.standard_answer",
"fetch_standard_answer_categories_by_ids",
_no_ee_standard_answer_categories,
)
)
existing_standard_answer_categories = (
versioned_fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
raise ValueError(

View File

@@ -1,202 +0,0 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import StandardAnswer
from danswer.db.models import StandardAnswerCategory
from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_category_validity(category_name: str) -> bool:
    """Reject category names that are too long (a Postgres unique constraint
    can only cover entries under 2704 bytes, and very long names are not
    usable anyway). Returns True when the name is acceptable."""
    if len(category_name) <= 255:
        return True
    logger.error(f"Category with name '{category_name}' is too long, cannot be used")
    return False
def insert_standard_answer_category(
    category_name: str, db_session: Session
) -> StandardAnswerCategory:
    """Validate the name, then create and commit a new category."""
    if not check_category_validity(category_name):
        raise ValueError(f"Invalid category name: {category_name}")
    new_category = StandardAnswerCategory(name=category_name)
    db_session.add(new_category)
    db_session.commit()
    return new_category
def insert_standard_answer(
    keyword: str,
    answer: str,
    category_ids: list[int],
    match_regex: bool,
    match_any_keywords: bool,
    db_session: Session,
) -> StandardAnswer:
    """Create an active standard answer linked to existing categories.

    Raises ValueError when any of the given category ids does not exist.
    """
    categories = fetch_standard_answer_categories_by_ids(
        standard_answer_category_ids=category_ids,
        db_session=db_session,
    )
    # A shorter result than requested means at least one id was unknown.
    if len(categories) != len(category_ids):
        raise ValueError(f"Some or all categories with ids {category_ids} do not exist")

    new_standard_answer = StandardAnswer(
        keyword=keyword,
        answer=answer,
        categories=categories,
        active=True,
        match_regex=match_regex,
        match_any_keywords=match_any_keywords,
    )
    db_session.add(new_standard_answer)
    db_session.commit()
    return new_standard_answer
def update_standard_answer(
    standard_answer_id: int,
    keyword: str,
    answer: str,
    category_ids: list[int],
    match_regex: bool,
    match_any_keywords: bool,
    db_session: Session,
) -> StandardAnswer:
    """Overwrite every editable field of an existing standard answer.

    Raises ValueError when the answer or any referenced category is unknown.
    """
    existing = db_session.scalar(
        select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
    )
    if existing is None:
        raise ValueError(f"No standard answer with id {standard_answer_id}")

    categories = fetch_standard_answer_categories_by_ids(
        standard_answer_category_ids=category_ids,
        db_session=db_session,
    )
    if len(categories) != len(category_ids):
        raise ValueError(f"Some or all categories with ids {category_ids} do not exist")

    existing.keyword = keyword
    existing.answer = answer
    existing.categories = list(categories)
    existing.match_regex = match_regex
    existing.match_any_keywords = match_any_keywords
    db_session.commit()
    return existing
def remove_standard_answer(
    standard_answer_id: int,
    db_session: Session,
) -> None:
    """Soft-delete a standard answer by flipping it to inactive."""
    target = db_session.scalar(
        select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
    )
    if target is None:
        raise ValueError(f"No standard answer with id {standard_answer_id}")
    target.active = False
    db_session.commit()
def update_standard_answer_category(
    standard_answer_category_id: int,
    category_name: str,
    db_session: Session,
) -> StandardAnswerCategory:
    """Rename an existing category, validating both the id and the new name."""
    category = db_session.scalar(
        select(StandardAnswerCategory).where(
            StandardAnswerCategory.id == standard_answer_category_id
        )
    )
    if category is None:
        raise ValueError(
            f"No standard answer category with id {standard_answer_category_id}"
        )
    if not check_category_validity(category_name):
        raise ValueError(f"Invalid category name: {category_name}")
    category.name = category_name
    db_session.commit()
    return category
def fetch_standard_answer_category(
    standard_answer_category_id: int,
    db_session: Session,
) -> StandardAnswerCategory | None:
    """Look up a single category by primary key; None when absent."""
    lookup_stmt = select(StandardAnswerCategory).where(
        StandardAnswerCategory.id == standard_answer_category_id
    )
    return db_session.scalar(lookup_stmt)
def fetch_standard_answer_categories_by_ids(
    standard_answer_category_ids: list[int],
    db_session: Session,
) -> Sequence[StandardAnswerCategory]:
    """Fetch every category whose id is in the given list.

    Unknown ids are silently skipped; callers compare lengths to detect them.
    """
    lookup_stmt = select(StandardAnswerCategory).where(
        StandardAnswerCategory.id.in_(standard_answer_category_ids)
    )
    return db_session.scalars(lookup_stmt).all()
def fetch_standard_answer_categories(
    db_session: Session,
) -> Sequence[StandardAnswerCategory]:
    """Return every standard-answer category in the database."""
    all_categories_stmt = select(StandardAnswerCategory)
    return db_session.scalars(all_categories_stmt).all()
def fetch_standard_answer(
    standard_answer_id: int,
    db_session: Session,
) -> StandardAnswer | None:
    """Look up a single standard answer by primary key; None when absent."""
    lookup_stmt = select(StandardAnswer).where(
        StandardAnswer.id == standard_answer_id
    )
    return db_session.scalar(lookup_stmt)
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
    """Return only the active (not soft-deleted) standard answers."""
    active_stmt = select(StandardAnswer).where(StandardAnswer.active.is_(True))
    return db_session.scalars(active_stmt).all()
def create_initial_default_standard_answer_category(db_session: Session) -> None:
    """Idempotently seed the default 'General' category at id 0.

    Raises ValueError when id 0 exists under a different name, since that
    indicates the DB is not in the expected initial state.
    """
    default_category_id = 0
    default_category_name = "General"
    existing = fetch_standard_answer_category(
        standard_answer_category_id=default_category_id,
        db_session=db_session,
    )
    if existing is not None:
        if existing.name != default_category_name:
            raise ValueError(
                "DB is not in a valid initial state. "
                "Default standard answer category does not have expected name."
            )
        return
    # NOTE(review): inserting with an explicit id presumably bypasses the id
    # sequence -- confirm the sequence cannot later collide with id 0.
    db_session.add(
        StandardAnswerCategory(
            id=default_category_id,
            name=default_category_name,
        )
    )
    db_session.commit()

View File

@@ -558,7 +558,15 @@ class Answer:
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
yield cast(str, item)
# this should never happen, but we're seeing weird behavior here so handling for now
if not isinstance(item, str):
logger.error(
f"Received non-string item in answer stream: {item}. Skipping."
)
continue
yield item
yield from process_answer_stream_fn(_stream())

View File

@@ -62,7 +62,6 @@ from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_secondary_search_settings
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
@@ -111,6 +110,8 @@ from danswer.server.query_and_chat.query_backend import (
from danswer.server.query_and_chat.query_backend import basic_router as query_router
from danswer.server.settings.api import admin_router as settings_admin_router
from danswer.server.settings.api import basic_router as settings_router
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
@@ -125,10 +126,10 @@ from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
@@ -186,9 +187,6 @@ def setup_postgres(db_session: Session) -> None:
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls()
@@ -245,6 +243,12 @@ def update_default_multipass_indexing(db_session: Session) -> None:
)
update_current_search_settings(db_session, updated_settings)
# Update settings with GPU availability
settings = load_settings()
settings.gpu_enabled = gpu_available
store_settings(settings)
logger.notice(f"Updated settings with GPU availability: {gpu_available}")
else:
logger.debug(
"Existing docs or connectors found. Skipping multipass indexing update."
@@ -591,7 +595,7 @@ def get_application() -> FastAPI:
application.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Change this to the list of allowed origins if needed
allow_origins=CORS_ALLOWED_ORIGIN, # Configurable via environment variable
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],

View File

@@ -26,6 +26,7 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.chat import update_search_docs_table_with_relevance
from danswer.db.engine import get_session_context_manager
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.persona import get_prompt_by_id
from danswer.llm.answering.answer import Answer
@@ -60,7 +61,7 @@ from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
logger = setup_logger()
@@ -118,7 +119,17 @@ def stream_answer_objects(
one_shot=True,
danswerbot_flow=danswerbot_flow,
)
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
temporary_persona: Persona | None = None
if query_req.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session, persona_config=query_req.persona_config, user=user
)
temporary_persona = new_persona
persona = temporary_persona if temporary_persona else chat_session.persona
llm, fast_llm = get_llms_for_persona(persona=persona)
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
@@ -153,11 +164,11 @@ def stream_answer_objects(
prompt_id=query_req.prompt_id, user=None, db_session=db_session
)
if prompt is None:
if not chat_session.persona.prompts:
if not persona.prompts:
raise RuntimeError(
"Persona does not have any prompts - this should never happen"
)
prompt = chat_session.persona.prompts[0]
prompt = persona.prompts[0]
# Create the first User query message
new_user_message = create_new_chat_message(
@@ -174,9 +185,7 @@ def stream_answer_objects(
prompt_config = PromptConfig.from_model(prompt)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
chat_session.persona.num_chunks
if chat_session.persona.num_chunks is not None
else default_num_chunks
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
),
max_tokens=max_document_tokens,
)
@@ -187,16 +196,16 @@ def stream_answer_objects(
evaluation_type=LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type,
persona=chat_session.persona,
persona=persona,
retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
bypass_acl=bypass_acl,
chunks_above=query_req.chunks_above,
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
bypass_acl=bypass_acl,
)
answer_config = AnswerStyleConfig(
@@ -209,13 +218,15 @@ def stream_answer_objects(
question=query_msg.message,
answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
force_use=True,
tool_name=search_tool.name,
args={"query": rephrased_query},
tools=[search_tool] if search_tool else [],
force_use_tool=(
ForceUseTool(
tool_name=search_tool.name,
args={"query": rephrased_query},
force_use=True,
)
),
# for now, don't use tool calling for this flow, as we haven't
# tested quotes with tool calling too much yet
@@ -223,9 +234,7 @@ def stream_answer_objects(
return_contexts=query_req.return_contexts,
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
)
# won't be any ImageGenerationDisplay responses since that tool is never passed in
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse):
@@ -261,6 +270,7 @@ def stream_answer_objects(
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
)
yield initial_response
elif packet.id == SEARCH_DOC_CONTENT_ID:
@@ -287,6 +297,7 @@ def stream_answer_objects(
relevance_summary=evaluation_response,
)
yield evaluation_response
else:
yield packet

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
@@ -8,6 +10,8 @@ from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
@@ -23,10 +27,49 @@ class ThreadMessage(BaseModel):
role: MessageType = MessageType.USER
class PromptConfig(BaseModel):
    """Inline prompt definition supplied as part of a PersonaConfig."""

    name: str
    # human-readable description; optional
    description: str = ""
    system_prompt: str
    task_prompt: str = ""
    # whether answers generated with this prompt should include citations
    include_citations: bool = True
    # whether the prompt should be made aware of the current date/time
    datetime_aware: bool = True
class DocumentSetConfig(BaseModel):
    # id of an existing document set to attach to the persona
    id: int
class ToolConfig(BaseModel):
    # id of an existing tool to attach to the persona
    id: int
class PersonaConfig(BaseModel):
    """Ad-hoc persona definition for requests that do not reference a
    persisted persona by id (used to build a temporary persona)."""

    name: str
    description: str
    search_type: SearchType = SearchType.SEMANTIC
    # number of chunks to feed the LLM; None falls back to the default
    num_chunks: float | None = None
    llm_relevance_filter: bool = False
    llm_filter_extraction: bool = False
    recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
    # override the default LLM provider/model for this persona, if set
    llm_model_provider_override: str | None = None
    llm_model_version_override: str | None = None

    # prompts may be given inline (`prompts`) or by reference (`prompt_ids`)
    prompts: list[PromptConfig] = Field(default_factory=list)
    prompt_ids: list[int] = Field(default_factory=list)

    document_set_ids: list[int] = Field(default_factory=list)
    # tools may likewise be referenced inline, by id, or as raw OpenAPI specs
    tools: list[ToolConfig] = Field(default_factory=list)
    tool_ids: list[int] = Field(default_factory=list)
    custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
class DirectQARequest(ChunkContext):
persona_config: PersonaConfig | None = None
persona_id: int | None = None
messages: list[ThreadMessage]
prompt_id: int | None
persona_id: int
prompt_id: int | None = None
multilingual_query_expansion: list[str] | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
rerank_settings: RerankingDetails | None = None
@@ -43,6 +86,12 @@ class DirectQARequest(ChunkContext):
# If True, skips generative an AI response to the search query
skip_gen_ai_answer_generation: bool = False
@model_validator(mode="after")
def check_persona_fields(self) -> "DirectQARequest":
if (self.persona_config is None) == (self.persona_id is None):
raise ValueError("Exactly one of persona_config or persona_id must be set")
return self
@model_validator(mode="after")
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
if self.chain_of_thought and self.prompt_id is not None:

View File

@@ -84,6 +84,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
# Multilingual Expansion
multilingual_expansion=search_settings.multilingual_expansion,
rerank_api_url=search_settings.rerank_api_url,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
)

View File

@@ -3,6 +3,7 @@ from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from pydantic import BaseModel
@@ -20,6 +21,7 @@ from danswer.db.persona import get_personas
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import mark_persona_as_not_deleted
from danswer.db.persona import update_all_personas_display_priority
from danswer.db.persona import update_persona_public_status
from danswer.db.persona import update_persona_shared_users
from danswer.db.persona import update_persona_visibility
from danswer.file_store.file_store import get_default_file_store
@@ -43,6 +45,10 @@ class IsVisibleRequest(BaseModel):
is_visible: bool
class IsPublicRequest(BaseModel):
    # desired public/private status for the persona
    is_public: bool
@admin_router.patch("/{persona_id}/visible")
def patch_persona_visibility(
persona_id: int,
@@ -58,6 +64,25 @@ def patch_persona_visibility(
)
@basic_router.patch("/{persona_id}/public")
def patch_user_presona_public_status(
    persona_id: int,
    is_public_request: IsPublicRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Set a persona's public/private status on behalf of the current user.

    Responds 403 when the DB layer rejects the change (it signals this by
    raising ValueError, e.g. for insufficient permissions — see
    `update_persona_public_status`).
    """
    try:
        update_persona_public_status(
            persona_id=persona_id,
            is_public=is_public_request.is_public,
            db_session=db_session,
            user=user,
        )
    except ValueError as e:
        # log with stack trace for debugging, but surface only the message
        logger.exception("Failed to update persona public status")
        raise HTTPException(status_code=403, detail=str(e))
@admin_router.put("/display-priority")
def patch_persona_display_priority(
display_priority_request: DisplayPriorityRequest,

View File

@@ -1,6 +1,4 @@
import re
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
from pydantic import BaseModel
@@ -17,13 +15,12 @@ from danswer.db.models import AllowedAnswerFilters
from danswer.db.models import ChannelConfig
from danswer.db.models import SlackBotConfig as SlackBotConfigModel
from danswer.db.models import SlackBotResponseType
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
from danswer.db.models import User
from danswer.search.models import SavedSearchSettings
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot
from ee.danswer.server.manage.models import StandardAnswerCategory
if TYPE_CHECKING:
@@ -119,95 +116,6 @@ class HiddenUpdateRequest(BaseModel):
hidden: bool
class StandardAnswerCategoryCreationRequest(BaseModel):
name: str
class StandardAnswerCategory(BaseModel):
id: int
name: str
@classmethod
def from_model(
cls, standard_answer_category: StandardAnswerCategoryModel
) -> "StandardAnswerCategory":
return cls(
id=standard_answer_category.id,
name=standard_answer_category.name,
)
class StandardAnswer(BaseModel):
id: int
keyword: str
answer: str
categories: list[StandardAnswerCategory]
match_regex: bool
match_any_keywords: bool
@classmethod
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
return cls(
id=standard_answer_model.id,
keyword=standard_answer_model.keyword,
answer=standard_answer_model.answer,
match_regex=standard_answer_model.match_regex,
match_any_keywords=standard_answer_model.match_any_keywords,
categories=[
StandardAnswerCategory.from_model(standard_answer_category_model)
for standard_answer_category_model in standard_answer_model.categories
],
)
class StandardAnswerCreationRequest(BaseModel):
keyword: str
answer: str
categories: list[int]
match_regex: bool
match_any_keywords: bool
@field_validator("categories", mode="before")
@classmethod
def validate_categories(cls, value: list[int]) -> list[int]:
if len(value) < 1:
raise ValueError(
"At least one category must be attached to a standard answer"
)
return value
@model_validator(mode="after")
def validate_only_match_any_if_not_regex(self) -> Any:
if self.match_regex and self.match_any_keywords:
raise ValueError(
"Can only match any keywords in keyword mode, not regex mode"
)
return self
@model_validator(mode="after")
def validate_keyword_if_regex(self) -> Any:
if not self.match_regex:
# no validation for keywords
return self
try:
re.compile(self.keyword)
return self
except re.error as err:
if isinstance(err.pattern, bytes):
raise ValueError(
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
)
else:
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
raise ValueError(
" ".join(
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
)
)
class SlackBotTokens(BaseModel):
bot_token: str
app_token: str
@@ -233,6 +141,7 @@ class SlackBotConfigCreationRequest(BaseModel):
# list of user emails
follow_up_tags: list[str] | None = None
response_type: SlackBotResponseType
# XXX this is going away soon
standard_answer_categories: list[int] = Field(default_factory=list)
@field_validator("answer_filters", mode="before")
@@ -257,6 +166,7 @@ class SlackBotConfig(BaseModel):
persona: PersonaSnapshot | None
channel_config: ChannelConfig
response_type: SlackBotResponseType
# XXX this is going away soon
standard_answer_categories: list[StandardAnswerCategory]
enable_auto_filters: bool
@@ -275,6 +185,7 @@ class SlackBotConfig(BaseModel):
),
channel_config=slack_bot_config_model.channel_config,
response_type=slack_bot_config_model.response_type,
# XXX this is going away soon
standard_answer_categories=[
StandardAnswerCategory.from_model(standard_answer_category_model)
for standard_answer_category_model in slack_bot_config_model.standard_answer_categories

View File

@@ -108,6 +108,7 @@ def create_slack_bot_config(
persona_id=persona_id,
channel_config=channel_config,
response_type=slack_bot_config_creation_request.response_type,
# XXX this is going away soon
standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories,
db_session=db_session,
enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,

View File

@@ -164,7 +164,7 @@ def get_chat_session(
chat_session_id=session_id,
description=chat_session.description,
persona_id=chat_session.persona_id,
persona_name=chat_session.persona.name,
persona_name=chat_session.persona.name if chat_session.persona else None,
current_alternate_model=chat_session.current_alternate_model,
messages=[
translate_db_message_to_chat_message_detail(

View File

@@ -136,7 +136,7 @@ class RenameChatSessionResponse(BaseModel):
class ChatSessionDetails(BaseModel):
id: int
name: str
persona_id: int
persona_id: int | None = None
time_created: str
shared_status: ChatSessionSharedStatus
folder_id: int | None = None
@@ -196,8 +196,8 @@ class SearchSessionDetailResponse(BaseModel):
class ChatSessionDetailResponse(BaseModel):
chat_session_id: int
description: str
persona_id: int
persona_name: str
persona_id: int | None = None
persona_name: str | None
messages: list[ChatMessageDetail]
time_created: datetime
shared_status: ChatSessionSharedStatus

View File

@@ -37,6 +37,7 @@ class Settings(BaseModel):
search_page_enabled: bool = True
default_page: PageType = PageType.SEARCH
maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
def check_validity(self) -> None:
chat_page_enabled = self.chat_page_enabled

View File

@@ -200,6 +200,7 @@ class ImageGenerationTool(Tool):
revised_prompt=response.data[0]["revised_prompt"],
url=response.data[0]["url"],
)
except Exception as e:
logger.debug(f"Error occured during image generation: {e}")

View File

@@ -0,0 +1,3 @@
class EERequiredError(Exception):
    """Raised when an Enterprise Edition (EE) feature or API is requested
    but the Enterprise Edition flag is not set for this deployment."""

View File

@@ -21,11 +21,11 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
from danswer.utils.logger import DanswerLoggingAdapter
from danswer.utils.logger import setup_logger
from ee.danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from ee.danswer.db.standard_answer import find_matching_standard_answers
from ee.danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
logger = setup_logger()

View File

@@ -12,6 +12,198 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_category_validity(category_name: str) -> bool:
    """Validate that a standard answer category name is usable.

    Extremely long names would break the Postgres unique constraint (which
    only applies to entries under 2704 bytes) and are not really usable or
    useful anyway, so anything over 255 characters is rejected.
    """
    is_valid = len(category_name) <= 255
    if not is_valid:
        logger.error(
            f"Category with name '{category_name}' is too long, cannot be used"
        )
    return is_valid
def insert_standard_answer_category(
    category_name: str, db_session: Session
) -> StandardAnswerCategory:
    """Create and persist a new standard answer category.

    Raises ValueError when the name fails validation (e.g. too long).
    """
    if not check_category_validity(category_name):
        raise ValueError(f"Invalid category name: {category_name}")

    new_category = StandardAnswerCategory(name=category_name)
    db_session.add(new_category)
    db_session.commit()
    return new_category
def insert_standard_answer(
    keyword: str,
    answer: str,
    category_ids: list[int],
    match_regex: bool,
    match_any_keywords: bool,
    db_session: Session,
) -> StandardAnswer:
    """Create and persist a new standard answer, initially active.

    Raises ValueError when one or more of the given category ids do not
    resolve to an existing category.
    """
    categories = fetch_standard_answer_categories_by_ids(
        standard_answer_category_ids=category_ids,
        db_session=db_session,
    )
    # a length mismatch means at least one requested id was not found
    if len(categories) != len(category_ids):
        raise ValueError(f"Some or all categories with ids {category_ids} do not exist")

    new_answer = StandardAnswer(
        keyword=keyword,
        answer=answer,
        categories=categories,
        active=True,
        match_regex=match_regex,
        match_any_keywords=match_any_keywords,
    )
    db_session.add(new_answer)
    db_session.commit()
    return new_answer
def update_standard_answer(
    standard_answer_id: int,
    keyword: str,
    answer: str,
    category_ids: list[int],
    match_regex: bool,
    match_any_keywords: bool,
    db_session: Session,
) -> StandardAnswer:
    """Overwrite all editable fields of an existing standard answer.

    Raises ValueError when the answer does not exist or when any of the
    given category ids is unknown.
    """
    existing = db_session.scalar(
        select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
    )
    if existing is None:
        raise ValueError(f"No standard answer with id {standard_answer_id}")

    categories = fetch_standard_answer_categories_by_ids(
        standard_answer_category_ids=category_ids,
        db_session=db_session,
    )
    # a length mismatch means at least one requested id was not found
    if len(categories) != len(category_ids):
        raise ValueError(f"Some or all categories with ids {category_ids} do not exist")

    existing.keyword = keyword
    existing.answer = answer
    existing.categories = list(categories)
    existing.match_regex = match_regex
    existing.match_any_keywords = match_any_keywords
    db_session.commit()
    return existing
def remove_standard_answer(
    standard_answer_id: int,
    db_session: Session,
) -> None:
    """Soft-delete a standard answer by marking it inactive.

    Raises ValueError when no standard answer with the given id exists.
    """
    target = db_session.scalar(
        select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
    )
    if target is None:
        raise ValueError(f"No standard answer with id {standard_answer_id}")

    # deactivate rather than delete so the row is preserved
    target.active = False
    db_session.commit()
def update_standard_answer_category(
    standard_answer_category_id: int,
    category_name: str,
    db_session: Session,
) -> StandardAnswerCategory:
    """Rename an existing standard answer category.

    Raises ValueError when the category does not exist or the new name is
    invalid (e.g. too long).
    """
    category = db_session.scalar(
        select(StandardAnswerCategory).where(
            StandardAnswerCategory.id == standard_answer_category_id
        )
    )
    if category is None:
        raise ValueError(
            f"No standard answer category with id {standard_answer_category_id}"
        )

    if not check_category_validity(category_name):
        raise ValueError(f"Invalid category name: {category_name}")

    category.name = category_name
    db_session.commit()
    return category
def fetch_standard_answer_category(
    standard_answer_category_id: int,
    db_session: Session,
) -> StandardAnswerCategory | None:
    """Look up a single standard answer category by id (None if absent)."""
    stmt = select(StandardAnswerCategory).where(
        StandardAnswerCategory.id == standard_answer_category_id
    )
    return db_session.scalar(stmt)
def fetch_standard_answer_categories_by_ids(
    standard_answer_category_ids: list[int],
    db_session: Session,
) -> Sequence[StandardAnswerCategory]:
    """Fetch every category whose id appears in the given list.

    Missing ids are silently omitted from the result; callers compare
    lengths to detect them.
    """
    stmt = select(StandardAnswerCategory).where(
        StandardAnswerCategory.id.in_(standard_answer_category_ids)
    )
    return db_session.scalars(stmt).all()
def fetch_standard_answer_categories(
    db_session: Session,
) -> Sequence[StandardAnswerCategory]:
    """Fetch all standard answer categories."""
    stmt = select(StandardAnswerCategory)
    return db_session.scalars(stmt).all()
def fetch_standard_answer(
    standard_answer_id: int,
    db_session: Session,
) -> StandardAnswer | None:
    """Look up a single standard answer by id (None if absent)."""
    stmt = select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
    return db_session.scalar(stmt)
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
    """Fetch all standard answers that have not been soft-deleted."""
    stmt = select(StandardAnswer).where(StandardAnswer.active.is_(True))
    return db_session.scalars(stmt).all()
def create_initial_default_standard_answer_category(db_session: Session) -> None:
    """Idempotently seed the default "General" category at the fixed id 0.

    If a category with the default id already exists, it must carry the
    expected name — otherwise the database is in an invalid initial state
    and a ValueError is raised.
    """
    default_id = 0
    default_name = "General"

    existing = fetch_standard_answer_category(
        standard_answer_category_id=default_id,
        db_session=db_session,
    )
    if existing is not None:
        if existing.name != default_name:
            raise ValueError(
                "DB is not in a valid initial state. "
                "Default standard answer category does not have expected name."
            )
        # already seeded correctly — nothing to do
        return

    db_session.add(StandardAnswerCategory(id=default_id, name=default_name))
    db_session.commit()
def fetch_standard_answer_categories_by_names(
standard_answer_category_names: list[str],
db_session: Session,

View File

@@ -0,0 +1,98 @@
import re
from typing import Any
from pydantic import BaseModel
from pydantic import field_validator
from pydantic import model_validator
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
class StandardAnswerCategoryCreationRequest(BaseModel):
    """Request body for creating a standard answer category."""

    # display name for the new category
    name: str
class StandardAnswerCategory(BaseModel):
    """API-facing snapshot of a standard answer category."""

    id: int
    name: str

    @classmethod
    def from_model(
        cls, standard_answer_category: StandardAnswerCategoryModel
    ) -> "StandardAnswerCategory":
        """Build the snapshot from the SQLAlchemy model."""
        return cls(
            id=standard_answer_category.id,
            name=standard_answer_category.name,
        )
class StandardAnswer(BaseModel):
    """API-facing snapshot of a standard answer and its categories."""

    id: int
    keyword: str
    answer: str
    categories: list[StandardAnswerCategory]
    # if True, `keyword` is treated as a regex pattern rather than keywords
    match_regex: bool
    # if True, any single keyword matching is sufficient (keyword mode only)
    match_any_keywords: bool

    @classmethod
    def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
        """Build the snapshot from the SQLAlchemy model, converting each
        attached category as well."""
        return cls(
            id=standard_answer_model.id,
            keyword=standard_answer_model.keyword,
            answer=standard_answer_model.answer,
            match_regex=standard_answer_model.match_regex,
            match_any_keywords=standard_answer_model.match_any_keywords,
            categories=[
                StandardAnswerCategory.from_model(standard_answer_category_model)
                for standard_answer_category_model in standard_answer_model.categories
            ],
        )
class StandardAnswerCreationRequest(BaseModel):
    """Request body for creating a standard answer.

    Validation enforces: at least one category; `match_any_keywords` only in
    keyword mode; and, in regex mode, that `keyword` compiles as a regex.
    """

    keyword: str
    answer: str
    # ids of categories to attach; at least one is required
    categories: list[int]
    match_regex: bool
    match_any_keywords: bool

    @field_validator("categories", mode="before")
    @classmethod
    def validate_categories(cls, value: list[int]) -> list[int]:
        # every standard answer must belong to at least one category
        if len(value) < 1:
            raise ValueError(
                "At least one category must be attached to a standard answer"
            )
        return value

    @model_validator(mode="after")
    def validate_only_match_any_if_not_regex(self) -> Any:
        # "match any" only makes sense for keyword matching, not regex
        if self.match_regex and self.match_any_keywords:
            raise ValueError(
                "Can only match any keywords in keyword mode, not regex mode"
            )
        return self

    @model_validator(mode="after")
    def validate_keyword_if_regex(self) -> Any:
        if not self.match_regex:
            # no validation for keywords
            return self

        try:
            # in regex mode the keyword must compile as a valid pattern
            re.compile(self.keyword)
            return self
        except re.error as err:
            # err.pattern may be bytes or str; format the message accordingly
            if isinstance(err.pattern, bytes):
                raise ValueError(
                    f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
                )
            else:
                pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
                raise ValueError(
                    " ".join(
                        ["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
                    )
                )

View File

@@ -6,19 +6,19 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.standard_answer import fetch_standard_answer
from danswer.db.standard_answer import fetch_standard_answer_categories
from danswer.db.standard_answer import fetch_standard_answer_category
from danswer.db.standard_answer import fetch_standard_answers
from danswer.db.standard_answer import insert_standard_answer
from danswer.db.standard_answer import insert_standard_answer_category
from danswer.db.standard_answer import remove_standard_answer
from danswer.db.standard_answer import update_standard_answer
from danswer.db.standard_answer import update_standard_answer_category
from danswer.server.manage.models import StandardAnswer
from danswer.server.manage.models import StandardAnswerCategory
from danswer.server.manage.models import StandardAnswerCategoryCreationRequest
from danswer.server.manage.models import StandardAnswerCreationRequest
from ee.danswer.db.standard_answer import fetch_standard_answer
from ee.danswer.db.standard_answer import fetch_standard_answer_categories
from ee.danswer.db.standard_answer import fetch_standard_answer_category
from ee.danswer.db.standard_answer import fetch_standard_answers
from ee.danswer.db.standard_answer import insert_standard_answer
from ee.danswer.db.standard_answer import insert_standard_answer_category
from ee.danswer.db.standard_answer import remove_standard_answer
from ee.danswer.db.standard_answer import update_standard_answer
from ee.danswer.db.standard_answer import update_standard_answer_category
from ee.danswer.server.manage.models import StandardAnswer
from ee.danswer.server.manage.models import StandardAnswerCategory
from ee.danswer.server.manage.models import StandardAnswerCategoryCreationRequest
from ee.danswer.server.manage.models import StandardAnswerCreationRequest
router = APIRouter(prefix="/manage")

View File

@@ -14,6 +14,7 @@ from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.process_message import ChatPacketStream
from danswer.chat.process_message import stream_chat_message_objects
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
@@ -28,6 +29,7 @@ from danswer.natural_language_processing.utils import get_tokenizer
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.search.models import SavedSearchDoc
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
@@ -65,21 +67,64 @@ def _translate_doc_response_to_simple_doc(
def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
simple_search_docs: list[SimpleDoc] | None,
top_docs: list[SavedSearchDoc] | None,
) -> list[int] | None:
"""
this function returns a list of indices of the simple search docs
that were actually fed to the LLM.
"""
if final_context_docs is None or simple_search_docs is None:
if final_context_docs is None or top_docs is None:
return None
final_context_doc_ids = {doc.document_id for doc in final_context_docs}
return [
i for i, doc in enumerate(simple_search_docs) if doc.id in final_context_doc_ids
i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids
]
def _convert_packet_stream_to_response(
    packets: ChatPacketStream,
) -> ChatBasicResponse:
    """Drain a chat packet stream and aggregate it into a ChatBasicResponse.

    Streamed answer pieces are concatenated into a single answer string;
    each other recognized packet type is copied onto the matching response
    field. Unrecognized packet types are ignored.
    """
    response = ChatBasicResponse()

    final_context_docs: list[LlmDoc] = []

    answer = ""
    for packet in packets:
        if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
            answer += packet.answer_piece
        elif isinstance(packet, QADocsResponse):
            response.top_documents = packet.top_documents

            # TODO: deprecate `simple_search_docs`
            response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
        elif isinstance(packet, StreamingError):
            response.error_msg = packet.error
        elif isinstance(packet, ChatMessageDetail):
            response.message_id = packet.message_id
        elif isinstance(packet, LLMRelevanceFilterResponse):
            response.llm_selected_doc_indices = packet.llm_selected_doc_indices

            # TODO: deprecate `llm_chunks_indices`
            response.llm_chunks_indices = packet.llm_selected_doc_indices
        elif isinstance(packet, FinalUsedContextDocsResponse):
            final_context_docs = packet.final_context_docs
        elif isinstance(packet, AllCitations):
            # map citation number -> document id for the client
            response.cited_documents = {
                citation.citation_num: citation.document_id
                for citation in packet.citations
            }

    # indices into top_documents of the docs actually fed to the LLM
    response.final_context_doc_indices = _get_final_context_doc_indices(
        final_context_docs, response.top_documents
    )

    response.answer = answer
    if answer:
        response.answer_citationless = remove_answer_citations(answer)

    return response
def remove_answer_citations(answer: str) -> str:
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
@@ -139,36 +184,7 @@ def handle_simplified_chat_message(
db_session=db_session,
)
response = ChatBasicResponse()
final_context_docs: list[LlmDoc] = []
answer = ""
for packet in packets:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
response.message_id = packet.message_id
elif isinstance(packet, FinalUsedContextDocsResponse):
final_context_docs = packet.final_context_docs
elif isinstance(packet, AllCitations):
response.cited_documents = {
citation.citation_num: citation.document_id
for citation in packet.citations
}
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.simple_search_docs
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)
return response
return _convert_packet_stream_to_response(packets)
@router.post("/send-message-simple-with-history")
@@ -287,35 +303,4 @@ def handle_send_message_simple_with_history(
db_session=db_session,
)
response = ChatBasicResponse()
final_context_docs: list[LlmDoc] = []
answer = ""
for packet in packets:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
response.message_id = packet.message_id
elif isinstance(packet, LLMRelevanceFilterResponse):
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, FinalUsedContextDocsResponse):
final_context_docs = packet.final_context_docs
elif isinstance(packet, AllCitations):
response.cited_documents = {
citation.citation_num: citation.document_id
for citation in packet.citations
}
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.simple_search_docs
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)
return response
return _convert_packet_stream_to_response(packets)

View File

@@ -8,7 +8,8 @@ from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
from danswer.server.manage.models import StandardAnswer
from danswer.search.models import SavedSearchDoc
from ee.danswer.server.manage.models import StandardAnswer
class StandardAnswerRequest(BaseModel):
@@ -73,10 +74,17 @@ class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
answer_citationless: str | None = None
simple_search_docs: list[SimpleDoc] | None = None
top_documents: list[SavedSearchDoc] | None = None
error_msg: str | None = None
message_id: int | None = None
llm_selected_doc_indices: list[int] | None = None
final_context_doc_indices: list[int] | None = None
# this is a map of the citation number to the document id
cited_documents: dict[int, str] | None = None
# FOR BACKWARDS COMPATIBILITY
# TODO: deprecate both of these
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None

View File

@@ -32,6 +32,7 @@ from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import (
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
logger = setup_logger()
@@ -133,12 +134,23 @@ def get_answer_with_quote(
query = query_request.messages[0].message
logger.notice(f"Received query for one shot answer API with quotes: {query}")
persona = get_persona_by_id(
persona_id=query_request.persona_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
if query_request.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session,
persona_config=query_request.persona_config,
user=user,
)
persona = new_persona
elif query_request.persona_id is not None:
persona = get_persona_by_id(
persona_id=query_request.persona_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
else:
raise KeyError("Must provide persona ID or Persona Config")
llm = get_main_llm_from_tuple(
get_default_llms() if not persona else get_llms_for_persona(persona)

View File

@@ -0,0 +1,83 @@
from typing import cast
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import is_user_admin
from danswer.db.llm import fetch_existing_doc_sets
from danswer.db.llm import fetch_existing_tools
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.one_shot_answer.models import PersonaConfig
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
def create_temporary_persona(
    persona_config: PersonaConfig, db_session: Session, user: User | None = None
) -> Persona:
    """Build a temporary (non-persisted) Persona from the provided configuration.

    Only admins may construct ad-hoc personas for one-shot queries; the
    permission check runs before any database work.

    Args:
        persona_config: Declarative persona description (prompts, tools,
            document sets, LLM overrides).
        db_session: Session used to resolve prompt IDs, tool IDs, and
            document-set IDs into ORM objects.
        user: Requesting user; authorization is decided by is_user_admin.

    Returns:
        A Persona populated from the config. It is NOT added to the session
        or committed — callers use it in-memory only.

    Raises:
        HTTPException: 403 if the user is not an admin.
    """
    # NOTE: this check must stay first so unauthorized callers do no DB work.
    # (Previously the docstring sat below this block, making it a dead string
    # expression rather than a real docstring.)
    if not is_user_admin(user):
        raise HTTPException(
            status_code=403,
            detail="User is not authorized to create a persona in one shot queries",
        )

    persona = Persona(
        name=persona_config.name,
        description=persona_config.description,
        num_chunks=persona_config.num_chunks,
        llm_relevance_filter=persona_config.llm_relevance_filter,
        llm_filter_extraction=persona_config.llm_filter_extraction,
        recency_bias=persona_config.recency_bias,
        llm_model_provider_override=persona_config.llm_model_provider_override,
        llm_model_version_override=persona_config.llm_model_version_override,
    )

    # Inline prompt definitions take precedence over referenced prompt IDs.
    if persona_config.prompts:
        persona.prompts = [
            Prompt(
                name=p.name,
                description=p.description,
                system_prompt=p.system_prompt,
                task_prompt=p.task_prompt,
                include_citations=p.include_citations,
                datetime_aware=p.datetime_aware,
            )
            for p in persona_config.prompts
        ]
    elif persona_config.prompt_ids:
        persona.prompts = get_prompts_by_ids(
            db_session=db_session, prompt_ids=persona_config.prompt_ids
        )

    # Tools can come from three sources and are accumulated (not exclusive):
    # inline OpenAPI schemas, full tool objects, and bare tool IDs.
    persona.tools = []
    if persona_config.custom_tools_openapi:
        for schema in persona_config.custom_tools_openapi:
            tools = cast(
                list[Tool],
                build_custom_tools_from_openapi_schema(schema),
            )
            persona.tools.extend(tools)

    if persona_config.tools:
        tool_ids = [tool.id for tool in persona_config.tools]
        persona.tools.extend(
            fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
        )

    if persona_config.tool_ids:
        persona.tools.extend(
            fetch_existing_tools(
                db_session=db_session, tool_ids=persona_config.tool_ids
            )
        )

    fetched_docs = fetch_existing_doc_sets(
        db_session=db_session, doc_ids=persona_config.document_set_ids
    )
    persona.document_sets = fetched_docs

    return persona

View File

@@ -87,7 +87,7 @@ class ChatSessionMinimal(BaseModel):
name: str | None
first_user_message: str
first_ai_message: str
persona_name: str
persona_name: str | None
time_created: datetime
feedback_type: QAFeedbackType | Literal["mixed"] | None
@@ -97,7 +97,7 @@ class ChatSessionSnapshot(BaseModel):
user_email: str
name: str | None
messages: list[MessageSnapshot]
persona_name: str
persona_name: str | None
time_created: datetime
@@ -111,7 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
retrieved_documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
persona_name: str
persona_name: str | None
user_email: str
time_created: datetime
@@ -145,7 +145,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
for ind, (user_message, ai_message) in enumerate(message_pairs)
]
def to_json(self) -> dict[str, str]:
def to_json(self) -> dict[str, str | None]:
return {
"chat_session_id": str(self.chat_session_id),
"message_pair_num": str(self.message_pair_num),
@@ -235,7 +235,9 @@ def fetch_and_process_chat_session_history_minimal(
name=chat_session.description,
first_user_message=first_user_message,
first_ai_message=first_ai_message,
persona_name=chat_session.persona.name,
persona_name=chat_session.persona.name
if chat_session.persona
else None,
time_created=chat_session.time_created,
feedback_type=feedback_type,
)
@@ -300,7 +302,7 @@ def snapshot_from_chat_session(
for message in messages
if message.message_type != MessageType.SYSTEM
],
persona_name=chat_session.persona.name,
persona_name=chat_session.persona.name if chat_session.persona else None,
time_created=chat_session.time_created,
)

View File

@@ -13,6 +13,9 @@ from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.settings.models import Settings
from danswer.server.settings.store import store_settings as store_base_settings
from danswer.utils.logger import setup_logger
from ee.danswer.db.standard_answer import (
create_initial_default_standard_answer_category,
)
from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
from ee.danswer.server.enterprise_settings.store import store_analytics_script
@@ -21,6 +24,7 @@ from ee.danswer.server.enterprise_settings.store import (
)
from ee.danswer.server.enterprise_settings.store import upload_logo
logger = setup_logger()
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
@@ -146,3 +150,6 @@ def seed_db() -> None:
_seed_logo(db_session, seed_config.seeded_logo_path)
_seed_enterprise_settings(seed_config)
_seed_analytics_script(seed_config)
logger.notice("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)

View File

@@ -1,4 +1,5 @@
import os
from urllib.parse import urlparse
# Used for logging
SLACK_CHANNEL_ID = "channel_id"
@@ -73,3 +74,18 @@ PRESERVED_SEARCH_FIELDS = [
"passage_prefix",
"query_prefix",
]
# CORS
def validate_cors_origin(origin: str) -> None:
    """Raise ValueError unless `origin` is an absolute http(s) URL with a host."""
    parsed = urlparse(origin)
    if parsed.scheme not in ["http", "https"] or not parsed.netloc:
        raise ValueError(f"Invalid CORS origin: '{origin}'")


# NOTE: str.split(",") always returns at least one element, so the previous
# `... or ["*"]` fallback could never trigger, and blank/whitespace entries
# (e.g. from a trailing comma) were kept un-stripped in the list while only
# the stripped copy was validated. Strip and filter up front instead, and
# fall back to the wildcard only when nothing usable remains.
CORS_ALLOWED_ORIGIN = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOWED_ORIGIN", "*").split(",")
    if origin.strip()
] or ["*"]

# Validate non-wildcard origins
for origin in CORS_ALLOWED_ORIGIN:
    if origin != "*":
        validate_cors_origin(origin)

View File

@@ -12,6 +12,7 @@ command=python danswer/background/update.py
redirect_stderr=true
autorestart=true
# Background jobs that must be run async due to long time to completion
# NOTE: due to an issue with Celery + SQLAlchemy
# (https://github.com/celery/celery/issues/7007#issuecomment-1740139367)
@@ -37,11 +38,9 @@ autorestart=true
# Job scheduler for periodic tasks
[program:celery_beat]
command=celery -A danswer.background.celery.celery_run:celery_app beat
--loglevel=INFO
--logfile=/var/log/celery_beat_supervisor.log
environment=LOG_FILE_NAME=celery_beat
redirect_stderr=true
autorestart=true
# Listens for Slack messages and responds with answers
# for all channels that the DanswerBot has been added to.
@@ -68,4 +67,4 @@ command=tail -qF
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
redirect_stderr=true
autorestart=true
autorestart=true

View File

@@ -51,6 +51,7 @@ def test_send_message_simple_with_history(reset: None) -> None:
# Check that the top document is the correct document
assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
# assert that the metadata is correct
for doc in cc_pair_1.documents:

View File

@@ -34,6 +34,7 @@ services:
- OAUTH_CLIENT_SECRET=${OAUTH_CLIENT_SECRET:-}
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
- TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-}
- CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-}
# Gen AI Settings
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}

View File

@@ -31,6 +31,7 @@ services:
- SMTP_PASS=${SMTP_PASS:-}
- EMAIL_FROM=${EMAIL_FROM:-}
- TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-}
- CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-}
# Gen AI Settings
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}

View File

@@ -27,4 +27,4 @@ spec:
key: redis_password
envFrom:
- configMapRef:
name: env-configmap
name: env-configmap

View File

@@ -13,6 +13,7 @@ data:
SMTP_USER: "" # 'your-email@company.com'
SMTP_PASS: "" # 'your-gmail-password'
EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead
CORS_ALLOWED_ORIGIN: ""
# Gen AI Settings
GEN_AI_MAX_TOKENS: ""
QA_TIMEOUT: "60"

View File

@@ -25,10 +25,8 @@ import { usePopup } from "@/components/admin/connectors/Popup";
import { getDisplayNameForModel } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { Option } from "@/components/Dropdown";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
import { useUserGroups } from "@/lib/hooks";
import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
import { checkLLMSupportsImageOutput, destructureValue } from "@/lib/llm/utils";
import { ToolSnapshot } from "@/lib/tools/interfaces";
import { checkUserIsNoAuthUser } from "@/lib/user";
@@ -47,7 +45,12 @@ import { FullLLMProvider } from "../configuration/llm/interfaces";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import { Persona, StarterMessage } from "./interfaces";
import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
import {
buildFinalPrompt,
createPersona,
providersContainImageGeneratingSupport,
updatePersona,
} from "./lib";
import { Popover } from "@/components/popover/Popover";
import {
CameraIcon,
@@ -167,7 +170,7 @@ export function AssistantEditor({
const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
);
const defaultProviderName = defaultProvider?.provider;
const defaultModelName = defaultProvider?.default_model_name;
const providerDisplayNameToProviderName = new Map<string, string>();
llmProviders.forEach((llmProvider) => {
@@ -187,10 +190,9 @@ export function AssistantEditor({
});
modelOptionsByProvider.set(llmProvider.name, providerOptions);
});
const providerSupportingImageGenerationExists = llmProviders.some(
(provider) =>
provider.provider === "openai" || provider.provider === "anthropic"
);
const providerSupportingImageGenerationExists =
providersContainImageGeneratingSupport(llmProviders);
const personaCurrentToolIds =
existingPersona?.tools.map((tool) => tool.id) || [];
@@ -342,7 +344,12 @@ export function AssistantEditor({
if (imageGenerationToolEnabled) {
if (
!checkLLMSupportsImageInput(
!checkLLMSupportsImageOutput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) ||
defaultProviderName ||
"",
values.llm_model_version_override || defaultModelName || ""
)
) {
@@ -453,6 +460,15 @@ export function AssistantEditor({
: false;
}
const currentLLMSupportsImageOutput = checkLLMSupportsImageOutput(
providerDisplayNameToProviderName.get(
values.llm_model_provider_override || ""
) ||
defaultProviderName ||
"",
values.llm_model_version_override || defaultModelName || ""
);
return (
<Form className="w-full text-text-950">
<div className="w-full flex gap-x-2 justify-center">
@@ -757,9 +773,7 @@ export function AssistantEditor({
<TooltipTrigger asChild>
<div
className={`w-fit ${
!checkLLMSupportsImageInput(
values.llm_model_version_override || ""
)
!currentLLMSupportsImageOutput
? "opacity-70 cursor-not-allowed"
: ""
}`}
@@ -771,17 +785,11 @@ export function AssistantEditor({
onChange={() => {
toggleToolInValues(imageGenerationTool.id);
}}
disabled={
!checkLLMSupportsImageInput(
values.llm_model_version_override || ""
)
}
disabled={!currentLLMSupportsImageOutput}
/>
</div>
</TooltipTrigger>
{!checkLLMSupportsImageInput(
values.llm_model_version_override || ""
) && (
{!currentLLMSupportsImageOutput && (
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
To use Image Generation, select GPT-4o or another
@@ -1051,15 +1059,15 @@ export function AssistantEditor({
<Field
name={`starter_messages[${index}].name`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
@@ -1081,15 +1089,15 @@ export function AssistantEditor({
<Field
name={`starter_messages.${index}.description`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
@@ -1112,15 +1120,15 @@ export function AssistantEditor({
<Field
name={`starter_messages[${index}].message`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
as="textarea"
autoComplete="off"
/>

View File

@@ -8,7 +8,11 @@ import { usePopup } from "@/components/admin/connectors/Popup";
import { useState, useMemo, useEffect } from "react";
import { UniqueIdentifier } from "@dnd-kit/core";
import { DraggableTable } from "@/components/table/DraggableTable";
import { deletePersona, personaComparator } from "./lib";
import {
deletePersona,
personaComparator,
togglePersonaVisibility,
} from "./lib";
import { FiEdit2 } from "react-icons/fi";
import { TrashIcon } from "@/components/icons/icons";
import { getCurrentUser } from "@/lib/user";
@@ -31,22 +35,6 @@ function PersonaTypeDisplay({ persona }: { persona: Persona }) {
return <Text>Personal {persona.owner && <>({persona.owner.email})</>}</Text>;
}
const togglePersonaVisibility = async (
personaId: number,
isVisible: boolean
) => {
const response = await fetch(`/api/admin/persona/${personaId}/visible`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
is_visible: !isVisible,
}),
});
return response;
};
export function PersonasTable({
allPersonas,
editablePersonas,

View File

@@ -1,3 +1,4 @@
import { FullLLMProvider } from "../configuration/llm/interfaces";
import { Persona, Prompt, StarterMessage } from "./interfaces";
interface PersonaCreationRequest {
@@ -318,3 +319,50 @@ export function personaComparator(a: Persona, b: Persona) {
return closerToZeroNegativesFirstComparator(a.id, b.id);
}
/** Flip a persona's visibility flag via the persona API (PATCH). */
export const togglePersonaVisibility = async (
  personaId: number,
  isVisible: boolean
) => {
  const payload = { is_visible: !isVisible };
  return fetch(`/api/persona/${personaId}/visible`, {
    method: "PATCH",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(payload),
  });
};
/** Set a persona's public/private status via the persona API (PATCH). */
export const togglePersonaPublicStatus = async (
  personaId: number,
  isPublic: boolean
) => {
  const payload = { is_public: isPublic };
  return fetch(`/api/persona/${personaId}/public`, {
    method: "PATCH",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(payload),
  });
};
/** True when the persona has the ImageGenerationTool attached. */
export function checkPersonaRequiresImageGeneration(persona: Persona) {
  return persona.tools.some((tool) => tool.name === "ImageGenerationTool");
}
/**
 * True when any configured LLM provider can power image generation
 * (currently only OpenAI).
 */
export function providersContainImageGeneratingSupport(
  providers: FullLLMProvider[]
) {
  for (const candidate of providers) {
    if (candidate.provider === "openai") {
      return true;
    }
  }
  return false;
}

View File

@@ -7,10 +7,8 @@ import { Button, Card, Text, Title } from "@tremor/react";
import useSWR from "swr";
import { ModelPreview } from "../../../../components/embedding/ModelSelector";
import {
AVAILABLE_CLOUD_PROVIDERS,
HostedEmbeddingModel,
CloudEmbeddingModel,
AVAILABLE_MODELS,
} from "@/components/embedding/interfaces";
import { ErrorCallout } from "@/components/ErrorCallout";
@@ -24,10 +22,7 @@ export interface EmbeddingDetails {
import { EmbeddingIcon } from "@/components/icons/icons";
import Link from "next/link";
import {
getCurrentModelCopy,
SavedSearchSettings,
} from "../../embeddings/interfaces";
import { SavedSearchSettings } from "../../embeddings/interfaces";
import UpgradingPage from "./UpgradingPage";
import { useContext } from "react";
import { SettingsContext } from "@/components/settings/SettingsProvider";

View File

@@ -27,6 +27,8 @@ import {
connectorConfigs,
createConnectorInitialValues,
createConnectorValidationSchema,
defaultPruneFreqDays,
defaultRefreshFreqMinutes,
} from "@/lib/connectors/connectors";
import { Modal } from "@/components/Modal";
import GDriveMain from "./pages/gdrive/GoogleDrivePage";
@@ -154,7 +156,6 @@ export default function AddConnector({
initialValues={createConnectorInitialValues(connector)}
validationSchema={createConnectorValidationSchema(connector)}
onSubmit={async (values) => {
console.log(" Iam submiing the connector");
const {
name,
groups,
@@ -189,9 +190,9 @@ export default function AddConnector({
// Apply advanced configuration-specific transforms.
const advancedConfiguration: any = {
pruneFreq: pruneFreq * 60 * 60 * 24,
pruneFreq: (pruneFreq || defaultPruneFreqDays) * 60 * 60 * 24,
indexingStart: convertStringToDateTime(indexingStart),
refreshFreq: refreshFreq * 60,
refreshFreq: (refreshFreq || defaultRefreshFreqMinutes) * 60,
};
// Google sites-specific handling

View File

@@ -1,9 +1,16 @@
import React, { Dispatch, forwardRef, SetStateAction, useState } from "react";
import React, {
Dispatch,
forwardRef,
SetStateAction,
useContext,
useState,
} from "react";
import { Formik, Form, FormikProps } from "formik";
import * as Yup from "yup";
import {
RerankerProvider,
RerankingDetails,
RerankingModel,
rerankingModels,
} from "./interfaces";
import { FiExternalLink } from "react-icons/fi";
@@ -15,6 +22,7 @@ import {
import { Modal } from "@/components/Modal";
import { Button } from "@tremor/react";
import { TextFormField } from "@/components/admin/connectors/Field";
import { SettingsContext } from "@/components/settings/SettingsProvider";
interface RerankingDetailsFormProps {
setRerankingDetails: Dispatch<SetStateAction<RerankingDetails>>;
@@ -38,10 +46,15 @@ const RerankingDetailsForm = forwardRef<
},
ref
) => {
const [showGpuWarningModalModel, setShowGpuWarningModalModel] =
useState<RerankingModel | null>(null);
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false);
const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] =
useState(false);
const combinedSettings = useContext(SettingsContext);
const gpuEnabled = combinedSettings?.settings.gpu_enabled;
return (
<Formik
innerRef={ref}
@@ -169,6 +182,11 @@ const RerankingDetailsForm = forwardRef<
RerankerProvider.LITELLM
) {
setShowLiteLLMConfigurationModal(true);
} else if (
!card.rerank_provider_type &&
!gpuEnabled
) {
setShowGpuWarningModalModel(card);
}
if (!isSelected) {
@@ -225,6 +243,33 @@ const RerankingDetailsForm = forwardRef<
})}
</div>
{showGpuWarningModalModel && (
<Modal
onOutsideClick={() => setShowGpuWarningModalModel(null)}
width="w-[500px] flex flex-col"
title="GPU Not Enabled"
>
<>
<p className="text-error font-semibold">Warning:</p>
<p>
Local reranking models require significant computational
resources and may perform slowly without GPU
acceleration. Consider switching to GPU-enabled
infrastructure or using a cloud-based alternative for
better performance.
</p>
<div className="flex justify-end">
<Button
onClick={() => setShowGpuWarningModalModel(null)}
color="blue"
size="xs"
>
Understood
</Button>
</div>
</>
</Modal>
)}
{showLiteLLMConfigurationModal && (
<Modal
onOutsideClick={() => {

View File

@@ -5,6 +5,7 @@ export interface Settings {
maximum_chat_retention_days: number | null;
notifications: Notification[];
needs_reindexing: boolean;
gpu_enabled: boolean;
}
export interface Notification {

View File

@@ -1,14 +1,7 @@
import { User } from "@/lib/types";
import { Persona } from "../admin/assistants/interfaces";
import { checkUserOwnsAssistant } from "@/lib/assistants/checkOwnership";
import {
FiImage,
FiLock,
FiMoreHorizontal,
FiSearch,
FiUnlock,
} from "react-icons/fi";
import { CustomTooltip } from "@/components/tooltip/CustomTooltip";
import { FiLock, FiUnlock } from "react-icons/fi";
export function AssistantSharedStatusDisplay({
assistant,
@@ -56,20 +49,6 @@ export function AssistantSharedStatusDisplay({
)}
</div>
)}
<div className="relative mt-4 text-xs flex text-subtle">
<span className="font-medium">Powers:</span>{" "}
{assistant.tools.length == 0 ? (
<p className="ml-2">None</p>
) : (
assistant.tools.map((tool, ind) => {
if (tool.name === "SearchTool") {
return <FiSearch key={ind} className="ml-1 h-3 w-3 my-auto" />;
} else if (tool.name === "ImageGenerationTool") {
return <FiImage key={ind} className="ml-1 h-3 w-3 my-auto" />;
}
})
)}
</div>
</div>
);
}

View File

@@ -6,11 +6,15 @@ import { Persona } from "@/app/admin/assistants/interfaces";
import { Divider, Text } from "@tremor/react";
import {
FiEdit2,
FiFigma,
FiMenu,
FiMinus,
FiMoreHorizontal,
FiPlus,
FiSearch,
FiShare,
FiShare2,
FiToggleLeft,
FiTrash,
FiX,
} from "react-icons/fi";
@@ -50,11 +54,14 @@ import {
verticalListSortingStrategy,
} from "@dnd-kit/sortable";
import { useSortable } from "@dnd-kit/sortable";
import { CSS } from "@dnd-kit/utilities";
import { DragHandle } from "@/components/table/DragHandle";
import { deletePersona } from "@/app/admin/assistants/lib";
import {
deletePersona,
togglePersonaPublicStatus,
} from "@/app/admin/assistants/lib";
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
import { MakePublicAssistantModal } from "@/app/chat/modal/MakePublicAssistantModal";
function DraggableAssistantListItem(props: any) {
const {
@@ -81,7 +88,7 @@ function DraggableAssistantListItem(props: any) {
<DragHandle />
</div>
<div className="flex-grow">
<AssistantListItem del {...props} />
<AssistantListItem {...props} />
</div>
</div>
);
@@ -95,6 +102,7 @@ function AssistantListItem({
isVisible,
setPopup,
deleteAssistant,
shareAssistant,
}: {
assistant: Persona;
user: User | null;
@@ -102,7 +110,7 @@ function AssistantListItem({
allAssistantIds: string[];
isVisible: boolean;
deleteAssistant: Dispatch<SetStateAction<Persona | null>>;
shareAssistant: Dispatch<SetStateAction<Persona | null>>;
setPopup: (popupSpec: PopupSpec | null) => void;
}) {
const router = useRouter();
@@ -258,6 +266,18 @@ function AssistantListItem({
) : (
<></>
),
isOwnedByUser ? (
<div
key="delete"
className="flex items-center gap-x-2"
onClick={() => shareAssistant(assistant)}
>
{assistant.is_public ? <FiMinus /> : <FiPlus />} Make{" "}
{assistant.is_public ? "Private" : "Public"}
</div>
) : (
<></>
),
]}
</DefaultPopover>
</div>
@@ -286,10 +306,15 @@ export function AssistantsList({
user?.preferences?.chosen_assistants &&
!user?.preferences?.chosen_assistants?.includes(assistant.id)
);
const allAssistantIds = assistants.map((assistant) =>
assistant.id.toString()
);
const [deletingPersona, setDeletingPersona] = useState<Persona | null>(null);
const [makePublicPersona, setMakePublicPersona] = useState<Persona | null>(
null
);
const { popup, setPopup } = usePopup();
const router = useRouter();
@@ -307,7 +332,7 @@ export function AssistantsList({
async function handleDragEnd(event: DragEndEvent) {
const { active, over } = event;
filteredAssistants;
if (over && active.id !== over.id) {
setFilteredAssistants((assistants) => {
const oldIndex = assistants.findIndex(
@@ -351,6 +376,20 @@ export function AssistantsList({
/>
)}
{makePublicPersona && (
<MakePublicAssistantModal
isPublic={makePublicPersona.is_public}
onClose={() => setMakePublicPersona(null)}
onShare={async (newPublicStatus: boolean) => {
await togglePersonaPublicStatus(
makePublicPersona.id,
newPublicStatus
);
router.refresh();
}}
/>
)}
<div className="mx-auto mobile:w-[90%] desktop:w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar">
<AssistantsPageTitle>My Assistants</AssistantsPageTitle>
@@ -403,6 +442,7 @@ export function AssistantsList({
{filteredAssistants.map((assistant, index) => (
<DraggableAssistantListItem
deleteAssistant={setDeletingPersona}
shareAssistant={setMakePublicPersona}
key={assistant.id}
assistant={assistant}
user={user}
@@ -431,6 +471,7 @@ export function AssistantsList({
{ownedButHiddenAssistants.map((assistant, index) => (
<AssistantListItem
deleteAssistant={setDeletingPersona}
shareAssistant={setMakePublicPersona}
key={assistant.id}
assistant={assistant}
user={user}

View File

@@ -426,7 +426,7 @@ export function ChatPage({
}, [existingChatSessionId]);
const [message, setMessage] = useState(
searchParams.get(SEARCH_PARAM_NAMES.USER_MESSAGE) || ""
searchParams.get(SEARCH_PARAM_NAMES.USER_PROMPT) || ""
);
const [completeMessageDetail, setCompleteMessageDetail] = useState<

View File

@@ -549,6 +549,7 @@ export function ChatInputBar({
tab
content={(close, ref) => (
<LlmTab
currentAssistant={alternativeAssistant || selectedAssistant}
openModelSettings={openModelSettings}
currentLlm={
llmOverrideManager.llmOverride.modelName ||

View File

@@ -582,7 +582,7 @@ export function personaIncludesImage(selectedPersona: Persona) {
const PARAMS_TO_SKIP = [
SEARCH_PARAM_NAMES.SUBMIT_ON_LOAD,
SEARCH_PARAM_NAMES.USER_MESSAGE,
SEARCH_PARAM_NAMES.USER_PROMPT,
SEARCH_PARAM_NAMES.TITLE,
// only use these if explicitly passed in
SEARCH_PARAM_NAMES.CHAT_ID,

View File

@@ -1,5 +1,4 @@
import React from "react";
import { useState, ReactNode } from "react";
import React, { useState, ReactNode, useCallback, useMemo, memo } from "react";
import { FiCheck, FiCopy } from "react-icons/fi";
const CODE_BLOCK_PADDING_TYPE = { padding: "1rem" };
@@ -11,21 +10,109 @@ interface CodeBlockProps {
[key: string]: any;
}
export function CodeBlock({
export const CodeBlock = memo(function CodeBlock({
className = "",
children,
content,
...props
}: CodeBlockProps) {
const language = className
.split(" ")
.filter((cls) => cls.startsWith("language-"))
.map((cls) => cls.replace("language-", ""))
.join(" ");
const [copied, setCopied] = useState(false);
const language = useMemo(() => {
return className
.split(" ")
.filter((cls) => cls.startsWith("language-"))
.map((cls) => cls.replace("language-", ""))
.join(" ");
}, [className]);
const codeText = useMemo(() => {
let codeText: string | null = null;
if (
props.node?.position?.start?.offset &&
props.node?.position?.end?.offset
) {
codeText = content.slice(
props.node.position.start.offset,
props.node.position.end.offset
);
codeText = codeText.trim();
// Find the last occurrence of closing backticks
const lastBackticksIndex = codeText.lastIndexOf("```");
if (lastBackticksIndex !== -1) {
codeText = codeText.slice(0, lastBackticksIndex + 3);
}
// Remove the language declaration and trailing backticks
const codeLines = codeText.split("\n");
if (
codeLines.length > 1 &&
(codeLines[0].startsWith("```") ||
codeLines[0].trim().startsWith("```"))
) {
codeLines.shift(); // Remove the first line with the language declaration
if (
codeLines[codeLines.length - 1] === "```" ||
codeLines[codeLines.length - 1]?.trim() === "```"
) {
codeLines.pop(); // Remove the last line with the trailing backticks
}
const minIndent = codeLines
.filter((line) => line.trim().length > 0)
.reduce((min, line) => {
const match = line.match(/^\s*/);
return Math.min(min, match ? match[0].length : 0);
}, Infinity);
const formattedCodeLines = codeLines.map((line) =>
line.slice(minIndent)
);
codeText = formattedCodeLines.join("\n");
}
}
// handle unknown languages. They won't have a `node.position.start.offset`
if (!codeText) {
const findTextNode = (node: any): string | null => {
if (node.type === "text") {
return node.value;
}
let finalResult = "";
if (node.children) {
for (const child of node.children) {
const result = findTextNode(child);
if (result) {
finalResult += result;
}
}
}
return finalResult;
};
codeText = findTextNode(props.node);
}
return codeText;
}, [content, props.node]);
const handleCopy = useCallback(
(event: React.MouseEvent) => {
event.preventDefault();
if (!codeText) {
return;
}
navigator.clipboard.writeText(codeText).then(() => {
setCopied(true);
setTimeout(() => setCopied(false), 2000);
});
},
[codeText]
);
if (!language) {
// this is the case of a single "`" e.g. `hi`
if (typeof children === "string") {
return <code className={className}>{children}</code>;
}
@@ -39,82 +126,6 @@ export function CodeBlock({
);
}
let codeText: string | null = null;
if (
props.node?.position?.start?.offset &&
props.node?.position?.end?.offset
) {
codeText = content.slice(
props.node.position.start.offset,
props.node.position.end.offset
);
codeText = codeText.trim();
// Find the last occurrence of closing backticks
const lastBackticksIndex = codeText.lastIndexOf("```");
if (lastBackticksIndex !== -1) {
codeText = codeText.slice(0, lastBackticksIndex + 3);
}
// Remove the language declaration and trailing backticks
const codeLines = codeText.split("\n");
if (
codeLines.length > 1 &&
(codeLines[0].startsWith("```") || codeLines[0].trim().startsWith("```"))
) {
codeLines.shift(); // Remove the first line with the language declaration
if (
codeLines[codeLines.length - 1] === "```" ||
codeLines[codeLines.length - 1]?.trim() === "```"
) {
codeLines.pop(); // Remove the last line with the trailing backticks
}
const minIndent = codeLines
.filter((line) => line.trim().length > 0)
.reduce((min, line) => {
const match = line.match(/^\s*/);
return Math.min(min, match ? match[0].length : 0);
}, Infinity);
const formattedCodeLines = codeLines.map((line) => line.slice(minIndent));
codeText = formattedCodeLines.join("\n");
}
}
// handle unknown languages. They won't have a `node.position.start.offset`
if (!codeText) {
const findTextNode = (node: any): string | null => {
if (node.type === "text") {
return node.value;
}
let finalResult = "";
if (node.children) {
for (const child of node.children) {
const result = findTextNode(child);
if (result) {
finalResult += result;
}
}
}
return finalResult;
};
codeText = findTextNode(props.node);
}
const handleCopy = (event: React.MouseEvent) => {
event.preventDefault();
if (!codeText) {
return;
}
navigator.clipboard.writeText(codeText).then(() => {
setCopied(true);
setTimeout(() => setCopied(false), 2000); // Reset copy status after 2 seconds
});
};
return (
<div className="overflow-x-hidden">
<div className="flex mx-3 py-2 text-xs">
@@ -143,4 +154,4 @@ export function CodeBlock({
</pre>
</div>
);
}
});

View File

@@ -0,0 +1,33 @@
import { Citation } from "@/components/search/results/Citation";
import React, { memo } from "react";
export const MemoizedLink = memo((props: any) => {
  const { node, ...rest } = props;
  // Citations / special markers are recognized by the first character of
  // the rendered children.
  const childText = rest.children?.toString();
  if (childText?.startsWith("*")) {
    // "*" marker renders as a small decorative dot, not a link.
    return (
      <div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
    );
  }
  if (childText?.startsWith("[")) {
    // "[n]"-style references render as inline citations.
    return <Citation link={rest?.href}>{rest.children}</Citation>;
  }
  // Plain link; opened via onMouseDown rather than href.
  return (
    <a
      onMouseDown={() =>
        rest.href ? window.open(rest.href, "_blank") : undefined
      }
      className="cursor-pointer text-link hover:text-link-hover"
    >
      {rest.children}
    </a>
  );
});
export const MemoizedParagraph = memo(({ node, ...props }: any) => (
<p {...props} className="text-default" />
));
MemoizedLink.displayName = "MemoizedLink";
MemoizedParagraph.displayName = "MemoizedParagraph";

View File

@@ -8,14 +8,7 @@ import {
FiGlobe,
} from "react-icons/fi";
import { FeedbackType } from "../types";
import {
Dispatch,
SetStateAction,
useContext,
useEffect,
useRef,
useState,
} from "react";
import React, { useContext, useEffect, useMemo, useRef, useState } from "react";
import ReactMarkdown from "react-markdown";
import {
DanswerDocument,
@@ -46,12 +39,7 @@ import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Citation } from "@/components/search/results/Citation";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
import {
ThumbsUpIcon,
ThumbsDownIcon,
LikeFeedback,
DislikeFeedback,
} from "@/components/icons/icons";
import { LikeFeedback, DislikeFeedback } from "@/components/icons/icons";
import {
CustomTooltip,
TooltipGroup,
@@ -65,6 +53,7 @@ import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import { ContinueGenerating } from "./ContinueMessage";
import { MemoizedLink, MemoizedParagraph } from "./MemoizedTextComponents";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@@ -367,41 +356,8 @@ export const AIMessage = ({
key={messageId}
className="prose max-w-full text-base"
components={{
a: (props) => {
const { node, ...rest } = props;
const value = rest.children;
if (value?.toString().startsWith("*")) {
return (
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
);
} else if (
value?.toString().startsWith("[")
) {
// for some reason <a> tags cause the onClick to not apply
// and the links are unclickable
// TODO: fix the fact that you have to double click to follow link
// for the first link
return (
<Citation link={rest?.href}>
{rest.children}
</Citation>
);
} else {
return (
<a
onMouseDown={() =>
rest.href
? window.open(rest.href, "_blank")
: undefined
}
className="cursor-pointer text-link hover:text-link-hover"
>
{rest.children}
</a>
);
}
},
a: MemoizedLink,
p: MemoizedParagraph,
code: (props) => (
<CodeBlock
className="w-full"
@@ -409,9 +365,6 @@ export const AIMessage = ({
content={content as string}
/>
),
p: ({ node, ...props }) => (
<p {...props} className="text-default" />
),
}}
remarkPlugins={[remarkGfm]}
rehypePlugins={[

View File

@@ -0,0 +1,72 @@
import { ModalWrapper } from "@/components/modals/ModalWrapper";
import { Button, Divider, Text } from "@tremor/react";
export function MakePublicAssistantModal({
isPublic,
onShare,
onClose,
}: {
isPublic: boolean;
onShare: (shared: boolean) => void;
onClose: () => void;
}) {
return (
<ModalWrapper onClose={onClose} modalClassName="max-w-3xl">
<div className="space-y-6">
<h2 className="text-2xl font-bold text-emphasis">
{isPublic ? "Public Assistant" : "Make Assistant Public"}
</h2>
<Text>
This assistant is currently{" "}
<span className="font-semibold">
{isPublic ? "public" : "private"}
</span>
.
{isPublic
? " Anyone can currently access this assistant."
: " Only you can access this assistant."}
</Text>
<Divider />
{isPublic ? (
<div className="space-y-4">
<Text>
To restrict access to this assistant, you can make it private
again.
</Text>
<Button
onClick={async () => {
await onShare?.(false);
onClose();
}}
size="sm"
color="red"
>
Make Assistant Private
</Button>
</div>
) : (
<div className="space-y-4">
<Text>
Making this assistant public will allow anyone with the link to
view and use it. Ensure that all content and capabilities of the
assistant are safe to share.
</Text>
<Button
onClick={async () => {
await onShare?.(true);
onClose();
}}
size="sm"
color="green"
>
Make Assistant Public
</Button>
</div>
)}
</div>
</ModalWrapper>
);
}

View File

@@ -4,10 +4,15 @@ import React, { forwardRef, useCallback, useState } from "react";
import { debounce } from "lodash";
import { Text } from "@tremor/react";
import { Persona } from "@/app/admin/assistants/interfaces";
import { destructureValue, structureValue } from "@/lib/llm/utils";
import {
checkLLMSupportsImageInput,
destructureValue,
structureValue,
} from "@/lib/llm/utils";
import { updateModelOverrideForChatSession } from "../../lib";
import { GearIcon } from "@/components/icons/icons";
import { LlmList } from "@/components/llm/LLMList";
import { checkPersonaRequiresImageGeneration } from "@/app/admin/assistants/lib";
interface LlmTabProps {
llmOverrideManager: LlmOverrideManager;
@@ -15,13 +20,24 @@ interface LlmTabProps {
openModelSettings: () => void;
chatSessionId?: number;
close: () => void;
currentAssistant: Persona;
}
export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
(
{ llmOverrideManager, chatSessionId, currentLlm, close, openModelSettings },
{
llmOverrideManager,
chatSessionId,
currentLlm,
close,
openModelSettings,
currentAssistant,
},
ref
) => {
const requiresImageGeneration =
checkPersonaRequiresImageGeneration(currentAssistant);
const { llmProviders } = useChatContext();
const { setLlmOverride, temperature, setTemperature } = llmOverrideManager;
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
@@ -55,6 +71,7 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
</button>
</div>
<LlmList
requiresImageGeneration={requiresImageGeneration}
llmProviders={llmProviders}
currentLlm={currentLlm}
onSelect={(value: string | null) => {

View File

@@ -1,453 +0,0 @@
import React, { useEffect, useRef, useState } from "react";
import { DocumentSet, Tag, ValidSources } from "@/lib/types";
import { SourceMetadata } from "@/lib/search/interfaces";
import {
FiBook,
FiBookmark,
FiCalendar,
FiFilter,
FiMap,
FiTag,
FiX,
} from "react-icons/fi";
import { DateRangePickerValue } from "@tremor/react";
import { listSourceMetadata } from "@/lib/sources";
import { SourceIcon } from "@/components/SourceIcon";
import { BasicClickable } from "@/components/BasicClickable";
import { ControlledPopup, DefaultDropdownElement } from "@/components/Dropdown";
import { getXDaysAgo } from "@/lib/dateUtils";
import { SourceSelectorProps } from "@/components/search/filtering/Filters";
import { containsObject, objectsAreEquivalent } from "@/lib/contains";
// Categories of filters the user can apply. The string values double as the
// display names shown in the filter-type dropdown (see SelectFilterType).
enum FilterType {
  Source = "Source",
  KnowledgeSet = "Knowledge Set",
  TimeRange = "Time Range",
  Tag = "Tag",
}
// Small removable "chip" showing one currently-applied filter; clicking
// anywhere on the chip (including the X icon) fires onClick.
function SelectedBubble({
  children,
  onClick,
}: {
  children: string | JSX.Element;
  onClick: () => void;
}) {
  const bubbleClasses =
    "flex text-xs cursor-pointer items-center border border-border " +
    "py-1 rounded-lg px-2 w-fit select-none hover:bg-hover";
  return (
    <div className={bubbleClasses} onClick={onClick}>
      {children}
      <FiX className="ml-2" size={14} />
    </div>
  );
}
// Top-level menu letting the user pick which kind of filter to configure.
// Source / Knowledge Set / Tag entries are hidden when nothing of that kind
// exists; Time Range is always available.
function SelectFilterType({
  onSelect,
  hasSources,
  hasKnowledgeSets,
  hasTags,
}: {
  onSelect: (filterType: FilterType) => void;
  hasSources: boolean;
  hasKnowledgeSets: boolean;
  hasTags: boolean;
}) {
  // Each entry: [filter type, icon, whether it should be shown].
  const entries: [FilterType, typeof FiMap, boolean][] = [
    [FilterType.Source, FiMap, hasSources],
    [FilterType.KnowledgeSet, FiBook, hasKnowledgeSets],
    [FilterType.Tag, FiTag, hasTags],
    [FilterType.TimeRange, FiCalendar, true],
  ];
  return (
    <div className="w-64">
      {entries
        .filter(([, , visible]) => visible)
        .map(([filterType, icon]) => (
          <DefaultDropdownElement
            key={filterType}
            name={filterType}
            icon={icon}
            onSelect={() => onSelect(filterType)}
            isSelected={false}
          />
        ))}
    </div>
  );
}
function SourcesSection({
sources,
selectedSources,
onSelect,
}: {
sources: SourceMetadata[];
selectedSources: string[];
onSelect: (source: SourceMetadata) => void;
}) {
return (
<div className="w-64">
{sources.map((source) => (
<DefaultDropdownElement
key={source.internalName}
name={source.displayName}
icon={source.icon}
onSelect={() => onSelect(source)}
isSelected={selectedSources.includes(source.internalName)}
includeCheckbox
/>
))}
</div>
);
}
// Checkbox list of document sets, keyed by (unique) set name.
function KnowledgeSetsSection({
  documentSets,
  selectedDocumentSets,
  onSelect,
}: {
  documentSets: DocumentSet[];
  selectedDocumentSets: string[];
  onSelect: (documentSetName: string) => void;
}) {
  return (
    <div className="w-64">
      {documentSets.map(({ name }) => (
        <DefaultDropdownElement
          key={name}
          name={name}
          icon={FiBookmark}
          onSelect={() => onSelect(name)}
          isSelected={selectedDocumentSets.includes(name)}
          includeCheckbox
        />
      ))}
    </div>
  );
}
const LAST_30_DAYS = "Last 30 days";
const LAST_7_DAYS = "Last 7 days";
const TODAY = "Today";
function TimeRangeSection({
selectedTimeRange,
onSelect,
}: {
selectedTimeRange: string | null;
onSelect: (timeRange: DateRangePickerValue) => void;
}) {
return (
<div className="w-64">
<DefaultDropdownElement
key={LAST_30_DAYS}
name={LAST_30_DAYS}
onSelect={() =>
onSelect({
to: new Date(),
from: getXDaysAgo(30),
selectValue: LAST_30_DAYS,
})
}
isSelected={selectedTimeRange === LAST_30_DAYS}
/>
<DefaultDropdownElement
key={LAST_7_DAYS}
name={LAST_7_DAYS}
onSelect={() =>
onSelect({
to: new Date(),
from: getXDaysAgo(7),
selectValue: LAST_7_DAYS,
})
}
isSelected={selectedTimeRange === LAST_7_DAYS}
/>
<DefaultDropdownElement
key={TODAY}
name={TODAY}
onSelect={() =>
onSelect({
to: new Date(),
from: getXDaysAgo(1),
selectValue: TODAY,
})
}
isSelected={selectedTimeRange === TODAY}
/>
</div>
);
}
// Searchable checkbox list of metadata tags. The filter input matches on a
// prefix of either the tag key or the tag value (case-insensitive).
function TagsSection({
  availableTags,
  selectedTags,
  onSelect,
}: {
  availableTags: Tag[];
  selectedTags: Tag[];
  onSelect: (tag: Tag) => void;
}) {
  const [filterValue, setFilterValue] = useState("");
  const inputRef = useRef<HTMLInputElement>(null);
  // Focus the search input as soon as the section opens.
  useEffect(() => {
    if (inputRef.current) {
      inputRef.current.focus();
    }
  }, []);
  const filterValueLower = filterValue.toLowerCase();
  const filteredTags = filterValueLower
    ? availableTags.filter(
        (tag) =>
          tag.tag_value.toLowerCase().startsWith(filterValueLower) ||
          tag.tag_key.toLowerCase().startsWith(filterValueLower)
      )
    : availableTags;
  return (
    <div className="w-96">
      <div className="max-h-48 overflow-y-auto">
        {filteredTags.length > 0 ? (
          filteredTags.map((tag) => (
            <DefaultDropdownElement
              key={tag.tag_key + tag.tag_value}
              name={
                <div className="max-w-full break-all line-clamp-1 text-ellipsis">
                  {tag.tag_key}
                  <b>=</b>
                  {tag.tag_value}
                </div>
              }
              onSelect={() => onSelect(tag)}
              // Compare structurally (key + value), not by reference, so the
              // checkbox stays consistent with handleTagToggle, which also
              // uses containsObject / objectsAreEquivalent.
              isSelected={containsObject(selectedTags, tag)}
              includeCheckbox
            />
          ))
        ) : (
          <div className="text-sm px-2 py-2">No matching tags found</div>
        )}
      </div>
      <div className="mx-2 mb-2 pt-2 border-t border-border">
        <input
          ref={inputRef}
          className="w-full border border-border py-0.5 px-2 rounded text-sm h-8 "
          placeholder="Find a tag"
          value={filterValue}
          onChange={(event) => setFilterValue(event.target.value)}
        />
      </div>
    </div>
  );
}
// Filter picker plus "currently applied" bubbles for the chat UI. All filter
// state lives in the caller (SourceSelectorProps setters); this component only
// wraps those setters with local toggle handlers and renders the popup.
export function ChatFilters({
  timeRange,
  setTimeRange,
  selectedSources,
  setSelectedSources,
  selectedDocumentSets,
  setSelectedDocumentSets,
  selectedTags,
  setSelectedTags,
  availableDocumentSets,
  existingSources,
  availableTags,
}: SourceSelectorProps) {
  const [filtersOpen, setFiltersOpen] = useState(false);
  // Opening or closing the popup always resets back to the top-level
  // filter-type chooser.
  const handleFiltersToggle = (value: boolean) => {
    setSelectedFilterType(null);
    setFiltersOpen(value);
  };
  // Which sub-menu (Source / Knowledge Set / Time Range / Tag) is shown;
  // null shows the filter-type chooser.
  const [selectedFilterType, setSelectedFilterType] =
    useState<FilterType | null>(null);
  // Toggle a source in/out of the selection, keyed by internalName.
  const handleSourceSelect = (source: SourceMetadata) => {
    setSelectedSources((prev: SourceMetadata[]) => {
      const prevSourceNames = prev.map((source) => source.internalName);
      if (prevSourceNames.includes(source.internalName)) {
        return prev.filter((s) => s.internalName !== source.internalName);
      } else {
        return [...prev, source];
      }
    });
  };
  // Toggle a document set by (unique) name.
  const handleDocumentSetSelect = (documentSetName: string) => {
    setSelectedDocumentSets((prev: string[]) => {
      if (prev.includes(documentSetName)) {
        return prev.filter((s) => s !== documentSetName);
      } else {
        return [...prev, documentSetName];
      }
    });
  };
  // Tags are compared structurally (key + value), not by reference.
  const handleTagToggle = (tag: Tag) => {
    setSelectedTags((prev) => {
      if (containsObject(prev, tag)) {
        return prev.filter((t) => !objectsAreEquivalent(t, tag));
      } else {
        return [...prev, tag];
      }
    });
  };
  // Only offer sources that actually exist in the index.
  const allSources = listSourceMetadata();
  const availableSources = allSources.filter((source) =>
    existingSources.includes(source.internalName)
  );
  // Pick the popup body for the currently-selected filter type; the final
  // else branch (no selection) shows the filter-type chooser.
  let popupDisplay = null;
  if (selectedFilterType === FilterType.Source) {
    popupDisplay = (
      <SourcesSection
        sources={availableSources}
        selectedSources={selectedSources.map((source) => source.internalName)}
        onSelect={handleSourceSelect}
      />
    );
  } else if (selectedFilterType === FilterType.KnowledgeSet) {
    popupDisplay = (
      <KnowledgeSetsSection
        documentSets={availableDocumentSets}
        selectedDocumentSets={selectedDocumentSets}
        onSelect={handleDocumentSetSelect}
      />
    );
  } else if (selectedFilterType === FilterType.TimeRange) {
    popupDisplay = (
      <TimeRangeSection
        selectedTimeRange={timeRange?.selectValue || null}
        onSelect={(timeRange) => {
          setTimeRange(timeRange);
          // Picking a time range is a one-shot action, so close the popup.
          handleFiltersToggle(!filtersOpen);
        }}
      />
    );
  } else if (selectedFilterType === FilterType.Tag) {
    popupDisplay = (
      <TagsSection
        availableTags={availableTags}
        selectedTags={selectedTags}
        onSelect={handleTagToggle}
      />
    );
  } else {
    popupDisplay = (
      <SelectFilterType
        onSelect={(filterType) => setSelectedFilterType(filterType)}
        hasSources={availableSources.length > 0}
        hasKnowledgeSets={availableDocumentSets.length > 0}
        hasTags={availableTags.length > 0}
      />
    );
  }
  return (
    <div className="flex">
      <ControlledPopup
        isOpen={filtersOpen}
        setIsOpen={handleFiltersToggle}
        popupContent={popupDisplay}
      >
        <div className="flex">
          <BasicClickable onClick={() => handleFiltersToggle(!filtersOpen)}>
            <div className="flex text-xs">
              <FiFilter className="my-auto mr-1" /> Filter
            </div>
          </BasicClickable>
        </div>
      </ControlledPopup>
      {/* Bubbles for every currently-applied filter; clicking one removes it. */}
      <div className="flex ml-4">
        {((timeRange && timeRange.selectValue !== undefined) ||
          selectedSources.length > 0 ||
          selectedDocumentSets.length > 0) && (
          <p className="text-xs my-auto mr-1">Currently applied:</p>
        )}
        <div className="flex flex-wrap gap-x-2">
          {timeRange && timeRange.selectValue && (
            <SelectedBubble onClick={() => setTimeRange(null)}>
              <div className="flex">{timeRange.selectValue}</div>
            </SelectedBubble>
          )}
          {existingSources.length > 0 &&
            selectedSources.map((source) => (
              <SelectedBubble
                key={source.internalName}
                onClick={() => handleSourceSelect(source)}
              >
                <>
                  <SourceIcon sourceType={source.internalName} iconSize={16} />
                  <span className="ml-2">{source.displayName}</span>
                </>
              </SelectedBubble>
            ))}
          {selectedDocumentSets.length > 0 &&
            selectedDocumentSets.map((documentSetName) => (
              <SelectedBubble
                key={documentSetName}
                onClick={() => handleDocumentSetSelect(documentSetName)}
              >
                <>
                  <div>
                    <FiBookmark />
                  </div>
                  <span className="ml-2">{documentSetName}</span>
                </>
              </SelectedBubble>
            ))}
          {selectedTags.length > 0 &&
            selectedTags.map((tag) => (
              <SelectedBubble
                key={tag.tag_key + tag.tag_value}
                onClick={() => handleTagToggle(tag)}
              >
                <>
                  <div>
                    <FiTag />
                  </div>
                  <span className="ml-1 max-w-[100px] text-ellipsis line-clamp-1 break-all">
                    {tag.tag_key}
                    <b>=</b>
                    {tag.tag_value}
                  </span>
                </>
              </SelectedBubble>
            ))}
        </div>
      </div>
    </div>
  );
}

View File

@@ -10,7 +10,7 @@ export const SEARCH_PARAM_NAMES = {
MODEL_VERSION: "model-version",
SYSTEM_PROMPT: "system-prompt",
// user message
USER_MESSAGE: "user-message",
USER_PROMPT: "user-prompt",
SUBMIT_ON_LOAD: "submit-on-load",
// chat title
TITLE: "title",

View File

@@ -0,0 +1,9 @@
import { getXDaysAgo, getXYearsAgo } from "@/lib/dateUtils";
// Preset relative date ranges shared by the search/chat filter UIs.
// NOTE(review): these Date values are computed once at module load, so in a
// long-lived SPA session "Today" can drift stale — confirm whether callers
// need per-open recomputation.
export const timeRangeValues = [
  { label: "Last 2 years", value: getXYearsAgo(2) },
  { label: "Last year", value: getXYearsAgo(1) },
  { label: "Last 30 days", value: getXDaysAgo(30) },
  { label: "Last 7 days", value: getXDaysAgo(7) },
  { label: "Today", value: getXDaysAgo(1) },
];

View File

@@ -44,7 +44,7 @@ export function Modal({
return (
<div
onMouseDown={handleMouseDown}
className={`fixed inset-0 bg-black bg-opacity-25 backdrop-blur-sm
className={`fixed inset-0 bg-black bg-opacity-25 backdrop-blur-sm h-full
flex items-center justify-center z-50 transition-opacity duration-300 ease-in-out`}
>
<div
@@ -72,7 +72,7 @@ export function Modal({
</div>
)}
<div className="flex w-full flex-col justify-stretch">
<div className="w-full flex flex-col h-full justify-stretch">
{title && (
<>
<div className="flex mb-4">

View File

@@ -1,4 +1,8 @@
import { CodeBlock } from "@/app/chat/message/CodeBlock";
import {
MemoizedLink,
MemoizedParagraph,
} from "@/app/chat/message/MemoizedTextComponents";
import React from "react";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
@@ -18,17 +22,8 @@ export const MinimalMarkdown: React.FC<MinimalMarkdownProps> = ({
<ReactMarkdown
className={`w-full text-wrap break-word ${className}`}
components={{
a: ({ node, ...props }) => (
<a
{...props}
className="text-sm text-link hover:text-link-hover"
target="_blank"
rel="noopener noreferrer"
/>
),
p: ({ node, ...props }) => (
<p {...props} className="text-wrap break-word text-sm m-0 w-full" />
),
a: MemoizedLink,
p: MemoizedParagraph,
code: useCodeBlock
? (props) => (
<CodeBlock className="w-full" {...props} content={content} />

View File

@@ -0,0 +1,32 @@
import { DefaultDropdownElement } from "../Dropdown";
// Dropdown body listing preset time ranges. Selecting one emits a range
// spanning [preset's start date, now], tagged with the preset's label.
// Props were previously all `any`; they now carry the structural types the
// known callers (e.g. DateRangeSelector) actually pass.
export function TimeRangeSelector({
  value,
  onValueChange,
  className,
  timeRangeValues,
}: {
  // Currently-selected range, matched against presets by `selectValue`.
  value: { selectValue?: string } | null;
  onValueChange: (range: {
    to: Date;
    from: Date;
    selectValue: string;
  }) => void;
  className?: string;
  timeRangeValues: { label: string; value: Date }[];
}) {
  return (
    <div className={className}>
      {timeRangeValues.map((timeRangeValue) => (
        <DefaultDropdownElement
          key={timeRangeValue.label}
          name={timeRangeValue.label}
          onSelect={() =>
            onValueChange({
              to: new Date(),
              from: timeRangeValue.value,
              selectValue: timeRangeValue.label,
            })
          }
          isSelected={value?.selectValue === timeRangeValue.label}
        />
      ))}
    </div>
  );
}

View File

@@ -1,6 +1,6 @@
import React from "react";
import { getDisplayNameForModel } from "@/lib/hooks";
import { structureValue } from "@/lib/llm/utils";
import { checkLLMSupportsImageInput, structureValue } from "@/lib/llm/utils";
import {
getProviderIcon,
LLMProviderDescriptor,
@@ -13,6 +13,7 @@ interface LlmListProps {
userDefault?: string | null;
scrollable?: boolean;
hideProviderIcon?: boolean;
requiresImageGeneration?: boolean;
}
export const LlmList: React.FC<LlmListProps> = ({
@@ -21,6 +22,7 @@ export const LlmList: React.FC<LlmListProps> = ({
onSelect,
userDefault,
scrollable,
requiresImageGeneration,
}) => {
const llmOptionsByProvider: {
[provider: string]: {
@@ -76,21 +78,26 @@ export const LlmList: React.FC<LlmListProps> = ({
User Default (currently {getDisplayNameForModel(userDefault)})
</button>
)}
{llmOptions.map(({ name, icon, value }, index) => (
<button
type="button"
key={index}
className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${
currentLlm == name
? "bg-background-200"
: "bg-background hover:bg-background-100"
} text-left rounded`}
onClick={() => onSelect(value)}
>
{icon({ size: 16 })}
{getDisplayNameForModel(name)}
</button>
))}
{llmOptions.map(({ name, icon, value }, index) => {
if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
return (
<button
type="button"
key={index}
className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${
currentLlm == name
? "bg-background-200"
: "bg-background hover:bg-background-100"
} text-left rounded`}
onClick={() => onSelect(value)}
>
{icon({ size: 16 })}
{getDisplayNameForModel(name)}
</button>
);
}
})}
</div>
);
};

View File

@@ -18,7 +18,7 @@ export default function ExceptionTraceModal({
title="Full Exception Trace"
onOutsideClick={onOutsideClick}
>
<div className="overflow-y-auto mb-6">
<div className="overflow-y-auto include-scrollbar pr-3 h-full mb-6">
<div className="mb-6">
{!copyClicked ? (
<div

View File

@@ -1,11 +1,8 @@
import { getXDaysAgo } from "@/lib/dateUtils";
import { DateRangePickerValue } from "@tremor/react";
import { FiCalendar, FiChevronDown, FiXCircle } from "react-icons/fi";
import { CustomDropdown, DefaultDropdownElement } from "../Dropdown";
export const LAST_30_DAYS = "Last 30 days";
export const LAST_7_DAYS = "Last 7 days";
export const TODAY = "Today";
import { CustomDropdown } from "../Dropdown";
import { timeRangeValues } from "@/app/config/timeRange";
import { TimeRangeSelector } from "@/components/filters/TimeRangeSelector";
export function DateRangeSelector({
value,
@@ -20,59 +17,23 @@ export function DateRangeSelector({
<div>
<CustomDropdown
dropdown={
<div
<TimeRangeSelector
value={value}
className={`
border
border-border
bg-background
rounded-lg
flex
flex-col
w-64
max-h-96
overflow-y-auto
flex
overscroll-contain`}
>
<DefaultDropdownElement
key={LAST_30_DAYS}
name={LAST_30_DAYS}
onSelect={() =>
onValueChange({
to: new Date(),
from: getXDaysAgo(30),
selectValue: LAST_30_DAYS,
})
}
isSelected={value?.selectValue === LAST_30_DAYS}
/>
<DefaultDropdownElement
key={LAST_7_DAYS}
name={LAST_7_DAYS}
onSelect={() =>
onValueChange({
to: new Date(),
from: getXDaysAgo(7),
selectValue: LAST_7_DAYS,
})
}
isSelected={value?.selectValue === LAST_7_DAYS}
/>
<DefaultDropdownElement
key={TODAY}
name={TODAY}
onSelect={() =>
onValueChange({
to: new Date(),
from: getXDaysAgo(1),
selectValue: TODAY,
})
}
isSelected={value?.selectValue === TODAY}
/>
</div>
border
border-border
bg-background
rounded-lg
flex
flex-col
w-64
max-h-96
overflow-y-auto
flex
overscroll-contain`}
timeRangeValues={timeRangeValues}
onValueChange={onValueChange}
/>
}
>
<div

View File

@@ -3,7 +3,6 @@
import {
DanswerDocument,
DocumentRelevance,
Relevance,
SearchDanswerDocument,
} from "@/lib/search/interfaces";
import { DocumentFeedbackBlock } from "./DocumentFeedbackBlock";
@@ -12,11 +11,10 @@ import { PopupSpec } from "../admin/connectors/Popup";
import { DocumentUpdatedAtBadge } from "./DocumentUpdatedAtBadge";
import { SourceIcon } from "../SourceIcon";
import { MetadataBadge } from "../MetadataBadge";
import { BookIcon, CheckmarkIcon, LightBulbIcon, XIcon } from "../icons/icons";
import { BookIcon, LightBulbIcon } from "../icons/icons";
import { FaStar } from "react-icons/fa";
import { FiTag } from "react-icons/fi";
import { DISABLE_LLM_DOC_RELEVANCE } from "@/lib/constants";
import { SettingsContext } from "../settings/SettingsProvider";
import { CustomTooltip, TooltipGroup } from "../tooltip/CustomTooltip";
import { WarningCircle } from "@phosphor-icons/react";

View File

@@ -42,6 +42,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
if (!results[0].ok) {
if (results[0].status === 403) {
settings = {
gpu_enabled: false,
chat_page_enabled: true,
search_page_enabled: true,
default_page: "search",

View File

@@ -14,15 +14,20 @@ export function orderAssistantsForUser(
])
);
assistants = assistants.filter((assistant) =>
let filteredAssistants = assistants.filter((assistant) =>
chosenAssistantsSet.has(assistant.id)
);
assistants.sort((a, b) => {
if (filteredAssistants.length == 0) {
return assistants;
}
filteredAssistants.sort((a, b) => {
const orderA = assistantOrderMap.get(a.id) ?? Number.MAX_SAFE_INTEGER;
const orderB = assistantOrderMap.get(b.id) ?? Number.MAX_SAFE_INTEGER;
return orderA - orderB;
});
return filteredAssistants;
}
return assistants;

View File

@@ -873,6 +873,9 @@ export function createConnectorValidationSchema(
});
}
export const defaultPruneFreqDays = 30; // 30 days
export const defaultRefreshFreqMinutes = 30; // 30 minutes
// CONNECTORS
export interface ConnectorBase<T> {
name: string;

View File

@@ -5,6 +5,13 @@ export function getXDaysAgo(daysAgo: number) {
return daysAgoDate;
}
// Returns a Date exactly `yearsAgo` years before now (same month, day, and
// time-of-day; Feb 29 rolls forward per Date.setFullYear semantics).
export function getXYearsAgo(yearsAgo: number) {
  const result = new Date();
  result.setFullYear(result.getFullYear() - yearsAgo);
  return result;
}
export const timestampToDateString = (timestamp: string) => {
const date = new Date(timestamp);
const year = date.getFullYear();

View File

@@ -62,7 +62,7 @@ export function getLLMProviderOverrideForPersona(
return null;
}
const MODEL_NAMES_SUPPORTING_IMAGES = [
const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
"gpt-4o",
"gpt-4o-mini",
"gpt-4-vision-preview",
@@ -84,8 +84,31 @@ const MODEL_NAMES_SUPPORTING_IMAGES = [
];
export function checkLLMSupportsImageInput(model: string) {
return MODEL_NAMES_SUPPORTING_IMAGES.some((modelName) => modelName === model);
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
(modelName) => modelName === model
);
}
// Models capable of image output. Both supporting providers expose the same
// OpenAI model family, so the supported set is the cross product of these two
// lists rather than a hand-maintained (and previously duplicated) pair list.
const MODELS_SUPPORTING_IMAGE_OUTPUT = [
  "gpt-4o",
  "gpt-4o-mini",
  "gpt-4-vision-preview",
  "gpt-4-turbo",
  "gpt-4-1106-vision-preview",
];
const PROVIDERS_SUPPORTING_IMAGE_OUTPUT = ["openai", "azure"];
/**
 * True iff the given provider/model pair supports image output
 * (generation), as opposed to image *input* — see checkLLMSupportsImageInput.
 */
export function checkLLMSupportsImageOutput(provider: string, model: string) {
  return (
    PROVIDERS_SUPPORTING_IMAGE_OUTPUT.includes(provider) &&
    MODELS_SUPPORTING_IMAGE_OUTPUT.includes(model)
  );
}
export const structureValue = (
name: string,
provider: string,