mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-27 12:45:51 +00:00
Compare commits
21 Commits
final_sequ
...
v0.5.21
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e662e3b57d | ||
|
|
2073820e33 | ||
|
|
5f25b243c5 | ||
|
|
a9427f190a | ||
|
|
18fbe9d7e8 | ||
|
|
75c9b1cafe | ||
|
|
632a8f700b | ||
|
|
cd58c96014 | ||
|
|
c5032d25c9 | ||
|
|
72acde6fd4 | ||
|
|
5596a68d08 | ||
|
|
5b18409c89 | ||
|
|
84272af5ac | ||
|
|
6bef70c8b7 | ||
|
|
7f7559e3d2 | ||
|
|
7ba829a585 | ||
|
|
8b2ecb4eab | ||
|
|
2dd3870504 | ||
|
|
df464fc54b | ||
|
|
96b98fbc4a | ||
|
|
66cf67d04d |
@@ -0,0 +1,64 @@
|
||||
"""server default chosen assistants
|
||||
|
||||
Revision ID: 35e6853a51d5
|
||||
Revises: c99d76fcd298
|
||||
Create Date: 2024-09-13 13:20:32.885317
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35e6853a51d5"
|
||||
down_revision = "c99d76fcd298"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
DEFAULT_ASSISTANTS = [-2, -1, 0]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Update any NULL values to the default value
|
||||
# This upgrades existing users without ordered assistant
|
||||
# to have default assistants set to visible assistants which are
|
||||
# accessible by them.
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user" u
|
||||
SET chosen_assistants = (
|
||||
SELECT jsonb_agg(
|
||||
p.id ORDER BY
|
||||
COALESCE(p.display_priority, 2147483647) ASC,
|
||||
p.id ASC
|
||||
)
|
||||
FROM persona p
|
||||
LEFT JOIN persona__user pu ON p.id = pu.persona_id AND pu.user_id = u.id
|
||||
WHERE p.is_visible = true
|
||||
AND (p.is_public = true OR pu.user_id IS NOT NULL)
|
||||
)
|
||||
WHERE chosen_assistants IS NULL
|
||||
OR chosen_assistants = 'null'
|
||||
OR jsonb_typeof(chosen_assistants) = 'null'
|
||||
OR (jsonb_typeof(chosen_assistants) = 'string' AND chosen_assistants = '"null"')
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 2: Alter the column to make it non-nullable
|
||||
op.alter_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
type_=postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default=sa.text(f"'{DEFAULT_ASSISTANTS}'::jsonb"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
type_=postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
server_default=None,
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add nullable to persona id in Chat Session
|
||||
|
||||
Revision ID: c99d76fcd298
|
||||
Revises: 5c7fdadae813
|
||||
Create Date: 2024-07-09 19:27:01.579697
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c99d76fcd298"
|
||||
down_revision = "5c7fdadae813"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -19,7 +19,9 @@ def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"standard_answer",
|
||||
sa.Column("match_regex", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column(
|
||||
"match_regex", sa.Boolean(), nullable=False, server_default=sa.false()
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -134,7 +134,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
@@ -189,7 +189,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
@@ -256,7 +256,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
|
||||
from danswer.configs.app_configs import REDIS_HOST
|
||||
from danswer.configs.app_configs import REDIS_PASSWORD
|
||||
from danswer.configs.app_configs import REDIS_PORT
|
||||
@@ -27,7 +29,7 @@ if REDIS_SSL:
|
||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
|
||||
|
||||
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
|
||||
# however, prefetching is bad when tasks are lengthy as those tasks
|
||||
@@ -42,3 +44,33 @@ broker_transport_options = {
|
||||
|
||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||
task_acks_late = True
|
||||
|
||||
# It's possible we don't even need celery's result backend, in which case all of the optimization below
|
||||
# might be irrelevant
|
||||
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
|
||||
|
||||
# Option 0: Defaults (json serializer, no compression)
|
||||
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
|
||||
|
||||
# Option 1: Reduces generator task result sizes by roughly 20%
|
||||
# task_compression = "bzip2"
|
||||
# task_serializer = "pickle"
|
||||
# result_compression = "bzip2"
|
||||
# result_serializer = "pickle"
|
||||
# accept_content=["pickle"]
|
||||
|
||||
# Option 2: this significantly reduces the size of the result for generator tasks since the list of children
|
||||
# can be large. small tasks change very little
|
||||
# def pickle_bz2_encoder(data):
|
||||
# return bz2.compress(pickle.dumps(data))
|
||||
|
||||
# def pickle_bz2_decoder(data):
|
||||
# return pickle.loads(bz2.decompress(data))
|
||||
|
||||
# from kombu import serialization # To register custom serialization with Celery/Kombu
|
||||
|
||||
# serialization.register('pickle-bzip2', pickle_bz2_encoder, pickle_bz2_decoder, 'application/x-pickle-bz2', 'binary')
|
||||
|
||||
# task_serializer = "pickle-bzip2"
|
||||
# result_serializer = "pickle-bzip2"
|
||||
# accept_content=["pickle", "pickle-bzip2"]
|
||||
|
||||
@@ -675,9 +675,11 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
dedupe_docs=(
|
||||
retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False
|
||||
),
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
@@ -743,10 +745,18 @@ def stream_chat_message_objects(
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
logger.debug("Reached end of stream")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(f"Failed to process chat message: {error_msg}")
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
error_msg = str(e)
|
||||
yield StreamingError(error=error_msg)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
error_msg = str(e)
|
||||
stack_trace = traceback.format_exc()
|
||||
client_error_msg = litellm_exception_to_error_msg(e, llm)
|
||||
if llm.config.api_key and len(llm.config.api_key) > 2:
|
||||
@@ -786,16 +796,18 @@ def stream_chat_message_objects(
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else [],
|
||||
tool_calls=(
|
||||
[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else []
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug("Committing messages")
|
||||
|
||||
@@ -159,11 +159,18 @@ REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
# Used by celery as broker and backend
|
||||
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15))
|
||||
REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
|
||||
os.environ.get("REDIS_DB_NUMBER_CELERY_RESULT_BACKEND", 14)
|
||||
)
|
||||
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
|
||||
|
||||
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "CERT_NONE")
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
||||
# should be one of "required", "optional", or "none"
|
||||
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
|
||||
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
|
||||
|
||||
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
#####
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import HTTPException
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import DividerBlock
|
||||
@@ -153,15 +154,23 @@ def handle_regular_answer(
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if len(new_message_request.messages) > 1:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
if new_message_request.persona_config:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Slack bot does not support persona config",
|
||||
)
|
||||
|
||||
elif new_message_request.persona_id:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
|
||||
@@ -226,7 +226,7 @@ def create_chat_session(
|
||||
db_session: Session,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int,
|
||||
persona_id: int | None, # Can be none if temporary persona is used
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
one_shot: bool = False,
|
||||
|
||||
@@ -4,9 +4,11 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import Tool as ToolModel
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
@@ -103,6 +105,20 @@ def fetch_existing_embedding_providers(
|
||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_existing_doc_sets(
|
||||
db_session: Session, doc_ids: list[int]
|
||||
) -> list[DocumentSet]:
|
||||
return list(
|
||||
db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
|
||||
return list(
|
||||
db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
|
||||
@@ -122,7 +122,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
# if specified, controls the assistants that are shown to the user + their order
|
||||
# if not specified, all assistants are shown
|
||||
chosen_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -866,7 +866,9 @@ class ChatSession(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
# One-shot direct answering, currently the two types of chats are not mixed
|
||||
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
@@ -900,7 +902,6 @@ class ChatSession(Base):
|
||||
prompt_override: Mapped[PromptOverride | None] = mapped_column(
|
||||
PydanticType(PromptOverride), nullable=True
|
||||
)
|
||||
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
@@ -909,7 +910,6 @@ class ChatSession(Base):
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
|
||||
folder: Mapped["ChatFolder"] = relationship(
|
||||
"ChatFolder", back_populates="chat_sessions"
|
||||
@@ -1347,55 +1347,6 @@ class ChannelConfig(TypedDict):
|
||||
follow_up_tags: NotRequired[list[str]]
|
||||
|
||||
|
||||
class StandardAnswerCategory(Base):
|
||||
__tablename__ = "standard_answer_category"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="categories",
|
||||
)
|
||||
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
|
||||
"SlackBotConfig",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answer_categories",
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(Base):
|
||||
__tablename__ = "standard_answer"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
keyword: Mapped[str] = mapped_column(String)
|
||||
answer: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
match_regex: Mapped[bool] = mapped_column(Boolean)
|
||||
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"unique_keyword_active",
|
||||
keyword,
|
||||
active,
|
||||
unique=True,
|
||||
postgresql_where=(active == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
chat_messages: Mapped[list[ChatMessage]] = relationship(
|
||||
"ChatMessage",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
|
||||
|
||||
class SlackBotResponseType(str, PyEnum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
@@ -1421,7 +1372,7 @@ class SlackBotConfig(Base):
|
||||
)
|
||||
|
||||
persona: Mapped[Persona | None] = relationship("Persona")
|
||||
standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="slack_bot_configs",
|
||||
@@ -1651,6 +1602,55 @@ class TokenRateLimit__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCategory(Base):
|
||||
__tablename__ = "standard_answer_category"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="categories",
|
||||
)
|
||||
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
|
||||
"SlackBotConfig",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answer_categories",
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(Base):
|
||||
__tablename__ = "standard_answer"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
keyword: Mapped[str] = mapped_column(String)
|
||||
answer: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
match_regex: Mapped[bool] = mapped_column(Boolean)
|
||||
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"unique_keyword_active",
|
||||
keyword,
|
||||
active,
|
||||
unique=True,
|
||||
postgresql_where=(active == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
chat_messages: Mapped[list[ChatMessage]] = relationship(
|
||||
"ChatMessage",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
|
||||
|
||||
"""Tables related to Permission Sync"""
|
||||
|
||||
|
||||
|
||||
@@ -210,6 +210,22 @@ def update_persona_shared_users(
|
||||
)
|
||||
|
||||
|
||||
def update_persona_public_status(
|
||||
persona_id: int,
|
||||
is_public: bool,
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
) -> None:
|
||||
persona = fetch_persona_by_id(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise ValueError("You don't have permission to modify this persona")
|
||||
|
||||
persona.is_public = is_public
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_prompts(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
@@ -551,6 +567,7 @@ def update_persona_visibility(
|
||||
persona = fetch_persona_by_id(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.is_visible = is_visible
|
||||
db_session.commit()
|
||||
|
||||
@@ -563,13 +580,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
|
||||
"""Unsafe, can fetch prompts from all users"""
|
||||
if not prompt_ids:
|
||||
return []
|
||||
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
|
||||
prompts = db_session.scalars(
|
||||
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
|
||||
).all()
|
||||
|
||||
return prompts
|
||||
return list(prompts)
|
||||
|
||||
|
||||
def get_prompt_by_id(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -14,8 +15,11 @@ from danswer.db.models import User
|
||||
from danswer.db.persona import get_default_prompt
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.utils.errors import EERequiredError
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
|
||||
|
||||
def _build_persona_name(channel_names: list[str]) -> str:
|
||||
@@ -70,6 +74,10 @@ def create_slack_bot_persona(
|
||||
return persona
|
||||
|
||||
|
||||
def _no_ee_standard_answer_categories(*args: Any, **kwargs: Any) -> list:
|
||||
return []
|
||||
|
||||
|
||||
def insert_slack_bot_config(
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
@@ -78,14 +86,29 @@ def insert_slack_bot_config(
|
||||
enable_auto_filters: bool,
|
||||
db_session: Session,
|
||||
) -> SlackBotConfig:
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
versioned_fetch_standard_answer_categories_by_ids = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.standard_answer",
|
||||
"fetch_standard_answer_categories_by_ids",
|
||||
_no_ee_standard_answer_categories,
|
||||
)
|
||||
)
|
||||
existing_standard_answer_categories = (
|
||||
versioned_fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
if len(existing_standard_answer_categories) == 0:
|
||||
raise EERequiredError(
|
||||
"Standard answers are a paid Enterprise Edition feature - enable EE or remove standard answer categories"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
)
|
||||
|
||||
slack_bot_config = SlackBotConfig(
|
||||
persona_id=persona_id,
|
||||
@@ -117,9 +140,18 @@ def update_slack_bot_config(
|
||||
f"Unable to find slack bot config with ID {slack_bot_config_id}"
|
||||
)
|
||||
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
versioned_fetch_standard_answer_categories_by_ids = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.standard_answer",
|
||||
"fetch_standard_answer_categories_by_ids",
|
||||
_no_ee_standard_answer_categories,
|
||||
)
|
||||
)
|
||||
existing_standard_answer_categories = (
|
||||
versioned_fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import StandardAnswer
|
||||
from danswer.db.models import StandardAnswerCategory
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_category_validity(category_name: str) -> bool:
|
||||
"""If a category name is too long, it should not be used (it will cause an error in Postgres
|
||||
as the unique constraint can only apply to entries that are less than 2704 bytes).
|
||||
|
||||
Additionally, extremely long categories are not really usable / useful."""
|
||||
if len(category_name) > 255:
|
||||
logger.error(
|
||||
f"Category with name '{category_name}' is too long, cannot be used"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def insert_standard_answer_category(
|
||||
category_name: str, db_session: Session
|
||||
) -> StandardAnswerCategory:
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
standard_answer_category = StandardAnswerCategory(name=category_name)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def insert_standard_answer(
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer = StandardAnswer(
|
||||
keyword=keyword,
|
||||
answer=answer,
|
||||
categories=existing_categories,
|
||||
active=True,
|
||||
match_regex=match_regex,
|
||||
match_any_keywords=match_any_keywords,
|
||||
)
|
||||
db_session.add(standard_answer)
|
||||
db_session.commit()
|
||||
return standard_answer
|
||||
|
||||
|
||||
def update_standard_answer(
|
||||
standard_answer_id: int,
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer.keyword = keyword
|
||||
standard_answer.answer = answer
|
||||
standard_answer.categories = list(existing_categories)
|
||||
standard_answer.match_regex = match_regex
|
||||
standard_answer.match_any_keywords = match_any_keywords
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer
|
||||
|
||||
|
||||
def remove_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
standard_answer.active = False
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
category_name: str,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory:
|
||||
standard_answer_category = db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
if standard_answer_category is None:
|
||||
raise ValueError(
|
||||
f"No standard answer category with id {standard_answer_category_id}"
|
||||
)
|
||||
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
|
||||
standard_answer_category.name = category_name
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def fetch_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id.in_(standard_answer_category_ids)
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories(
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(select(StandardAnswerCategory)).all()
|
||||
|
||||
|
||||
def fetch_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswer).where(StandardAnswer.active.is_(True))
|
||||
).all()
|
||||
|
||||
|
||||
def create_initial_default_standard_answer_category(db_session: Session) -> None:
|
||||
default_category_id = 0
|
||||
default_category_name = "General"
|
||||
default_category = fetch_standard_answer_category(
|
||||
standard_answer_category_id=default_category_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if default_category is not None:
|
||||
if default_category.name != default_category_name:
|
||||
raise ValueError(
|
||||
"DB is not in a valid initial state. "
|
||||
"Default standard answer category does not have expected name."
|
||||
)
|
||||
return
|
||||
|
||||
standard_answer_category = StandardAnswerCategory(
|
||||
id=default_category_id,
|
||||
name=default_category_name,
|
||||
)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
@@ -558,7 +558,15 @@ class Answer:
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
yield cast(str, item)
|
||||
|
||||
# this should never happen, but we're seeing weird behavior here so handling for now
|
||||
if not isinstance(item, str):
|
||||
logger.error(
|
||||
f"Received non-string item in answer stream: {item}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
yield item
|
||||
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
|
||||
|
||||
@@ -62,7 +62,6 @@ from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_current_search_settings
|
||||
from danswer.db.search_settings import update_secondary_search_settings
|
||||
from danswer.db.standard_answer import create_initial_default_standard_answer_category
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
@@ -111,6 +110,8 @@ from danswer.server.query_and_chat.query_backend import (
|
||||
from danswer.server.query_and_chat.query_backend import basic_router as query_router
|
||||
from danswer.server.settings.api import admin_router as settings_admin_router
|
||||
from danswer.server.settings.api import basic_router as settings_router
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.server.settings.store import store_settings
|
||||
from danswer.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
@@ -125,10 +126,10 @@ from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import CORS_ALLOWED_ORIGIN
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -186,9 +187,6 @@ def setup_postgres(db_session: Session) -> None:
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
logger.notice("Verifying default standard answer category exists.")
|
||||
create_initial_default_standard_answer_category(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
delete_old_default_personas(db_session)
|
||||
load_chat_yamls()
|
||||
@@ -245,6 +243,12 @@ def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
)
|
||||
update_current_search_settings(db_session, updated_settings)
|
||||
|
||||
# Update settings with GPU availability
|
||||
settings = load_settings()
|
||||
settings.gpu_enabled = gpu_available
|
||||
store_settings(settings)
|
||||
logger.notice(f"Updated settings with GPU availability: {gpu_available}")
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"Existing docs or connectors found. Skipping multipass indexing update."
|
||||
@@ -591,7 +595,7 @@ def get_application() -> FastAPI:
|
||||
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Change this to the list of allowed origins if needed
|
||||
allow_origins=CORS_ALLOWED_ORIGIN, # Configurable via environment variable
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -26,6 +26,7 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.chat import update_search_docs_table_with_relevance
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompt_by_id
|
||||
from danswer.llm.answering.answer import Answer
|
||||
@@ -60,7 +61,7 @@ from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -118,7 +119,17 @@ def stream_answer_objects(
|
||||
one_shot=True,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
)
|
||||
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
|
||||
|
||||
temporary_persona: Persona | None = None
|
||||
if query_req.persona_config is not None:
|
||||
new_persona = create_temporary_persona(
|
||||
db_session=db_session, persona_config=query_req.persona_config, user=user
|
||||
)
|
||||
temporary_persona = new_persona
|
||||
|
||||
persona = temporary_persona if temporary_persona else chat_session.persona
|
||||
|
||||
llm, fast_llm = get_llms_for_persona(persona=persona)
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
@@ -153,11 +164,11 @@ def stream_answer_objects(
|
||||
prompt_id=query_req.prompt_id, user=None, db_session=db_session
|
||||
)
|
||||
if prompt is None:
|
||||
if not chat_session.persona.prompts:
|
||||
if not persona.prompts:
|
||||
raise RuntimeError(
|
||||
"Persona does not have any prompts - this should never happen"
|
||||
)
|
||||
prompt = chat_session.persona.prompts[0]
|
||||
prompt = persona.prompts[0]
|
||||
|
||||
# Create the first User query message
|
||||
new_user_message = create_new_chat_message(
|
||||
@@ -174,9 +185,7 @@ def stream_answer_objects(
|
||||
prompt_config = PromptConfig.from_model(prompt)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
chat_session.persona.num_chunks
|
||||
if chat_session.persona.num_chunks is not None
|
||||
else default_num_chunks
|
||||
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
|
||||
),
|
||||
max_tokens=max_document_tokens,
|
||||
)
|
||||
@@ -187,16 +196,16 @@ def stream_answer_objects(
|
||||
evaluation_type=LLMEvaluationType.SKIP
|
||||
if DISABLE_LLM_DOC_RELEVANCE
|
||||
else query_req.evaluation_type,
|
||||
persona=chat_session.persona,
|
||||
persona=persona,
|
||||
retrieval_options=query_req.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
bypass_acl=bypass_acl,
|
||||
chunks_above=query_req.chunks_above,
|
||||
chunks_below=query_req.chunks_below,
|
||||
full_doc=query_req.full_doc,
|
||||
bypass_acl=bypass_acl,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
@@ -209,13 +218,15 @@ def stream_answer_objects(
|
||||
question=query_msg.message,
|
||||
answer_style_config=answer_config,
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
|
||||
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool],
|
||||
force_use_tool=ForceUseTool(
|
||||
force_use=True,
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
tools=[search_tool] if search_tool else [],
|
||||
force_use_tool=(
|
||||
ForceUseTool(
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
force_use=True,
|
||||
)
|
||||
),
|
||||
# for now, don't use tool calling for this flow, as we haven't
|
||||
# tested quotes with tool calling too much yet
|
||||
@@ -223,9 +234,7 @@ def stream_answer_objects(
|
||||
return_contexts=query_req.return_contexts,
|
||||
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
# won't be any ImageGenerationDisplay responses since that tool is never passed in
|
||||
|
||||
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||
# for one-shot flow, don't currently do anything with these
|
||||
if isinstance(packet, ToolResponse):
|
||||
@@ -261,6 +270,7 @@ def stream_answer_objects(
|
||||
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
|
||||
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
|
||||
)
|
||||
|
||||
yield initial_response
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
@@ -287,6 +297,7 @@ def stream_answer_objects(
|
||||
relevance_summary=evaluation_response,
|
||||
)
|
||||
yield evaluation_response
|
||||
|
||||
else:
|
||||
yield packet
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
@@ -8,6 +10,8 @@ from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import ChunkContext
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import RetrievalDetails
|
||||
@@ -23,10 +27,49 @@ class ThreadMessage(BaseModel):
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class DocumentSetConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PersonaConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DirectQARequest(ChunkContext):
|
||||
persona_config: PersonaConfig | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None
|
||||
persona_id: int
|
||||
prompt_id: int | None = None
|
||||
multilingual_query_expansion: list[str] | None = None
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
@@ -43,6 +86,12 @@ class DirectQARequest(ChunkContext):
|
||||
# If True, skips generative an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_persona_fields(self) -> "DirectQARequest":
|
||||
if (self.persona_config is None) == (self.persona_id is None):
|
||||
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
|
||||
if self.chain_of_thought and self.prompt_id is not None:
|
||||
|
||||
@@ -84,6 +84,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
# Multilingual Expansion
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
rerank_api_url=search_settings.rerank_api_url,
|
||||
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
@@ -20,6 +21,7 @@ from danswer.db.persona import get_personas
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import mark_persona_as_not_deleted
|
||||
from danswer.db.persona import update_all_personas_display_priority
|
||||
from danswer.db.persona import update_persona_public_status
|
||||
from danswer.db.persona import update_persona_shared_users
|
||||
from danswer.db.persona import update_persona_visibility
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
@@ -43,6 +45,10 @@ class IsVisibleRequest(BaseModel):
|
||||
is_visible: bool
|
||||
|
||||
|
||||
class IsPublicRequest(BaseModel):
|
||||
is_public: bool
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/visible")
|
||||
def patch_persona_visibility(
|
||||
persona_id: int,
|
||||
@@ -58,6 +64,25 @@ def patch_persona_visibility(
|
||||
)
|
||||
|
||||
|
||||
@basic_router.patch("/{persona_id}/public")
|
||||
def patch_user_presona_public_status(
|
||||
persona_id: int,
|
||||
is_public_request: IsPublicRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_persona_public_status(
|
||||
persona_id=persona_id,
|
||||
is_public=is_public_request.is_public,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona public status")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.put("/display-priority")
|
||||
def patch_persona_display_priority(
|
||||
display_priority_request: DisplayPriorityRequest,
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -17,13 +15,12 @@ from danswer.db.models import AllowedAnswerFilters
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import SlackBotConfig as SlackBotConfigModel
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
|
||||
from danswer.db.models import User
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
from ee.danswer.server.manage.models import StandardAnswerCategory
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -119,95 +116,6 @@ class HiddenUpdateRequest(BaseModel):
|
||||
hidden: bool
|
||||
|
||||
|
||||
class StandardAnswerCategoryCreationRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class StandardAnswerCategory(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, standard_answer_category: StandardAnswerCategoryModel
|
||||
) -> "StandardAnswerCategory":
|
||||
return cls(
|
||||
id=standard_answer_category.id,
|
||||
name=standard_answer_category.name,
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(BaseModel):
|
||||
id: int
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[StandardAnswerCategory]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
|
||||
return cls(
|
||||
id=standard_answer_model.id,
|
||||
keyword=standard_answer_model.keyword,
|
||||
answer=standard_answer_model.answer,
|
||||
match_regex=standard_answer_model.match_regex,
|
||||
match_any_keywords=standard_answer_model.match_any_keywords,
|
||||
categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in standard_answer_model.categories
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCreationRequest(BaseModel):
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[int]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@field_validator("categories", mode="before")
|
||||
@classmethod
|
||||
def validate_categories(cls, value: list[int]) -> list[int]:
|
||||
if len(value) < 1:
|
||||
raise ValueError(
|
||||
"At least one category must be attached to a standard answer"
|
||||
)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_only_match_any_if_not_regex(self) -> Any:
|
||||
if self.match_regex and self.match_any_keywords:
|
||||
raise ValueError(
|
||||
"Can only match any keywords in keyword mode, not regex mode"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_keyword_if_regex(self) -> Any:
|
||||
if not self.match_regex:
|
||||
# no validation for keywords
|
||||
return self
|
||||
|
||||
try:
|
||||
re.compile(self.keyword)
|
||||
return self
|
||||
except re.error as err:
|
||||
if isinstance(err.pattern, bytes):
|
||||
raise ValueError(
|
||||
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
|
||||
)
|
||||
else:
|
||||
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
|
||||
raise ValueError(
|
||||
" ".join(
|
||||
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SlackBotTokens(BaseModel):
|
||||
bot_token: str
|
||||
app_token: str
|
||||
@@ -233,6 +141,7 @@ class SlackBotConfigCreationRequest(BaseModel):
|
||||
# list of user emails
|
||||
follow_up_tags: list[str] | None = None
|
||||
response_type: SlackBotResponseType
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories: list[int] = Field(default_factory=list)
|
||||
|
||||
@field_validator("answer_filters", mode="before")
|
||||
@@ -257,6 +166,7 @@ class SlackBotConfig(BaseModel):
|
||||
persona: PersonaSnapshot | None
|
||||
channel_config: ChannelConfig
|
||||
response_type: SlackBotResponseType
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories: list[StandardAnswerCategory]
|
||||
enable_auto_filters: bool
|
||||
|
||||
@@ -275,6 +185,7 @@ class SlackBotConfig(BaseModel):
|
||||
),
|
||||
channel_config=slack_bot_config_model.channel_config,
|
||||
response_type=slack_bot_config_model.response_type,
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in slack_bot_config_model.standard_answer_categories
|
||||
|
||||
@@ -108,6 +108,7 @@ def create_slack_bot_config(
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=slack_bot_config_creation_request.response_type,
|
||||
# XXX this is going away soon
|
||||
standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories,
|
||||
db_session=db_session,
|
||||
enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,
|
||||
|
||||
@@ -164,7 +164,7 @@ def get_chat_session(
|
||||
chat_session_id=session_id,
|
||||
description=chat_session.description,
|
||||
persona_id=chat_session.persona_id,
|
||||
persona_name=chat_session.persona.name,
|
||||
persona_name=chat_session.persona.name if chat_session.persona else None,
|
||||
current_alternate_model=chat_session.current_alternate_model,
|
||||
messages=[
|
||||
translate_db_message_to_chat_message_detail(
|
||||
|
||||
@@ -136,7 +136,7 @@ class RenameChatSessionResponse(BaseModel):
|
||||
class ChatSessionDetails(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
persona_id: int
|
||||
persona_id: int | None = None
|
||||
time_created: str
|
||||
shared_status: ChatSessionSharedStatus
|
||||
folder_id: int | None = None
|
||||
@@ -196,8 +196,8 @@ class SearchSessionDetailResponse(BaseModel):
|
||||
class ChatSessionDetailResponse(BaseModel):
|
||||
chat_session_id: int
|
||||
description: str
|
||||
persona_id: int
|
||||
persona_name: str
|
||||
persona_id: int | None = None
|
||||
persona_name: str | None
|
||||
messages: list[ChatMessageDetail]
|
||||
time_created: datetime
|
||||
shared_status: ChatSessionSharedStatus
|
||||
|
||||
@@ -37,6 +37,7 @@ class Settings(BaseModel):
|
||||
search_page_enabled: bool = True
|
||||
default_page: PageType = PageType.SEARCH
|
||||
maximum_chat_retention_days: int | None = None
|
||||
gpu_enabled: bool | None = None
|
||||
|
||||
def check_validity(self) -> None:
|
||||
chat_page_enabled = self.chat_page_enabled
|
||||
|
||||
@@ -200,6 +200,7 @@ class ImageGenerationTool(Tool):
|
||||
revised_prompt=response.data[0]["revised_prompt"],
|
||||
url=response.data[0]["url"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error occured during image generation: {e}")
|
||||
|
||||
|
||||
3
backend/danswer/utils/errors.py
Normal file
3
backend/danswer/utils/errors.py
Normal file
@@ -0,0 +1,3 @@
|
||||
class EERequiredError(Exception):
|
||||
"""This error is thrown if an Enterprise Edition feature or API is
|
||||
requested but the Enterprise Edition flag is not set."""
|
||||
@@ -21,11 +21,11 @@ from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer_categories_by_names
|
||||
from ee.danswer.db.standard_answer import find_matching_standard_answers
|
||||
from ee.danswer.server.manage.models import StandardAnswer as PydanticStandardAnswer
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -12,6 +12,198 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_category_validity(category_name: str) -> bool:
|
||||
"""If a category name is too long, it should not be used (it will cause an error in Postgres
|
||||
as the unique constraint can only apply to entries that are less than 2704 bytes).
|
||||
|
||||
Additionally, extremely long categories are not really usable / useful."""
|
||||
if len(category_name) > 255:
|
||||
logger.error(
|
||||
f"Category with name '{category_name}' is too long, cannot be used"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def insert_standard_answer_category(
|
||||
category_name: str, db_session: Session
|
||||
) -> StandardAnswerCategory:
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
standard_answer_category = StandardAnswerCategory(name=category_name)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def insert_standard_answer(
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer = StandardAnswer(
|
||||
keyword=keyword,
|
||||
answer=answer,
|
||||
categories=existing_categories,
|
||||
active=True,
|
||||
match_regex=match_regex,
|
||||
match_any_keywords=match_any_keywords,
|
||||
)
|
||||
db_session.add(standard_answer)
|
||||
db_session.commit()
|
||||
return standard_answer
|
||||
|
||||
|
||||
def update_standard_answer(
|
||||
standard_answer_id: int,
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer.keyword = keyword
|
||||
standard_answer.answer = answer
|
||||
standard_answer.categories = list(existing_categories)
|
||||
standard_answer.match_regex = match_regex
|
||||
standard_answer.match_any_keywords = match_any_keywords
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer
|
||||
|
||||
|
||||
def remove_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
standard_answer.active = False
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
category_name: str,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory:
|
||||
standard_answer_category = db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
if standard_answer_category is None:
|
||||
raise ValueError(
|
||||
f"No standard answer category with id {standard_answer_category_id}"
|
||||
)
|
||||
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
|
||||
standard_answer_category.name = category_name
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def fetch_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id.in_(standard_answer_category_ids)
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories(
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(select(StandardAnswerCategory)).all()
|
||||
|
||||
|
||||
def fetch_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswer).where(StandardAnswer.active.is_(True))
|
||||
).all()
|
||||
|
||||
|
||||
def create_initial_default_standard_answer_category(db_session: Session) -> None:
|
||||
default_category_id = 0
|
||||
default_category_name = "General"
|
||||
default_category = fetch_standard_answer_category(
|
||||
standard_answer_category_id=default_category_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if default_category is not None:
|
||||
if default_category.name != default_category_name:
|
||||
raise ValueError(
|
||||
"DB is not in a valid initial state. "
|
||||
"Default standard answer category does not have expected name."
|
||||
)
|
||||
return
|
||||
|
||||
standard_answer_category = StandardAnswerCategory(
|
||||
id=default_category_id,
|
||||
name=default_category_name,
|
||||
)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_names(
|
||||
standard_answer_category_names: list[str],
|
||||
db_session: Session,
|
||||
|
||||
98
backend/ee/danswer/server/manage/models.py
Normal file
98
backend/ee/danswer/server/manage/models.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
|
||||
|
||||
|
||||
class StandardAnswerCategoryCreationRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class StandardAnswerCategory(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, standard_answer_category: StandardAnswerCategoryModel
|
||||
) -> "StandardAnswerCategory":
|
||||
return cls(
|
||||
id=standard_answer_category.id,
|
||||
name=standard_answer_category.name,
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(BaseModel):
|
||||
id: int
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[StandardAnswerCategory]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
|
||||
return cls(
|
||||
id=standard_answer_model.id,
|
||||
keyword=standard_answer_model.keyword,
|
||||
answer=standard_answer_model.answer,
|
||||
match_regex=standard_answer_model.match_regex,
|
||||
match_any_keywords=standard_answer_model.match_any_keywords,
|
||||
categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in standard_answer_model.categories
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCreationRequest(BaseModel):
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[int]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@field_validator("categories", mode="before")
|
||||
@classmethod
|
||||
def validate_categories(cls, value: list[int]) -> list[int]:
|
||||
if len(value) < 1:
|
||||
raise ValueError(
|
||||
"At least one category must be attached to a standard answer"
|
||||
)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_only_match_any_if_not_regex(self) -> Any:
|
||||
if self.match_regex and self.match_any_keywords:
|
||||
raise ValueError(
|
||||
"Can only match any keywords in keyword mode, not regex mode"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_keyword_if_regex(self) -> Any:
|
||||
if not self.match_regex:
|
||||
# no validation for keywords
|
||||
return self
|
||||
|
||||
try:
|
||||
re.compile(self.keyword)
|
||||
return self
|
||||
except re.error as err:
|
||||
if isinstance(err.pattern, bytes):
|
||||
raise ValueError(
|
||||
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
|
||||
)
|
||||
else:
|
||||
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
|
||||
raise ValueError(
|
||||
" ".join(
|
||||
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
|
||||
)
|
||||
)
|
||||
@@ -6,19 +6,19 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.standard_answer import fetch_standard_answer
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories
|
||||
from danswer.db.standard_answer import fetch_standard_answer_category
|
||||
from danswer.db.standard_answer import fetch_standard_answers
|
||||
from danswer.db.standard_answer import insert_standard_answer
|
||||
from danswer.db.standard_answer import insert_standard_answer_category
|
||||
from danswer.db.standard_answer import remove_standard_answer
|
||||
from danswer.db.standard_answer import update_standard_answer
|
||||
from danswer.db.standard_answer import update_standard_answer_category
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
from danswer.server.manage.models import StandardAnswerCategory
|
||||
from danswer.server.manage.models import StandardAnswerCategoryCreationRequest
|
||||
from danswer.server.manage.models import StandardAnswerCreationRequest
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer_categories
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answer_category
|
||||
from ee.danswer.db.standard_answer import fetch_standard_answers
|
||||
from ee.danswer.db.standard_answer import insert_standard_answer
|
||||
from ee.danswer.db.standard_answer import insert_standard_answer_category
|
||||
from ee.danswer.db.standard_answer import remove_standard_answer
|
||||
from ee.danswer.db.standard_answer import update_standard_answer
|
||||
from ee.danswer.db.standard_answer import update_standard_answer_category
|
||||
from ee.danswer.server.manage.models import StandardAnswer
|
||||
from ee.danswer.server.manage.models import StandardAnswerCategory
|
||||
from ee.danswer.server.manage.models import StandardAnswerCategoryCreationRequest
|
||||
from ee.danswer.server.manage.models import StandardAnswerCreationRequest
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.process_message import ChatPacketStream
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
@@ -28,6 +29,7 @@ from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.one_shot_answer.qa_utils import combine_message_thread
|
||||
from danswer.search.models import OptionalSearchSetting
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
@@ -65,21 +67,64 @@ def _translate_doc_response_to_simple_doc(
|
||||
|
||||
def _get_final_context_doc_indices(
|
||||
final_context_docs: list[LlmDoc] | None,
|
||||
simple_search_docs: list[SimpleDoc] | None,
|
||||
top_docs: list[SavedSearchDoc] | None,
|
||||
) -> list[int] | None:
|
||||
"""
|
||||
this function returns a list of indices of the simple search docs
|
||||
that were actually fed to the LLM.
|
||||
"""
|
||||
if final_context_docs is None or simple_search_docs is None:
|
||||
if final_context_docs is None or top_docs is None:
|
||||
return None
|
||||
|
||||
final_context_doc_ids = {doc.document_id for doc in final_context_docs}
|
||||
return [
|
||||
i for i, doc in enumerate(simple_search_docs) if doc.id in final_context_doc_ids
|
||||
i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids
|
||||
]
|
||||
|
||||
|
||||
def _convert_packet_stream_to_response(
|
||||
packets: ChatPacketStream,
|
||||
) -> ChatBasicResponse:
|
||||
response = ChatBasicResponse()
|
||||
final_context_docs: list[LlmDoc] = []
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.top_documents = packet.top_documents
|
||||
|
||||
# TODO: deprecate `simple_search_docs`
|
||||
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
|
||||
# TODO: deprecate `llm_chunks_indices`
|
||||
response.llm_chunks_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, FinalUsedContextDocsResponse):
|
||||
final_context_docs = packet.final_context_docs
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.cited_documents = {
|
||||
citation.citation_num: citation.document_id
|
||||
for citation in packet.citations
|
||||
}
|
||||
|
||||
response.final_context_doc_indices = _get_final_context_doc_indices(
|
||||
final_context_docs, response.top_documents
|
||||
)
|
||||
|
||||
response.answer = answer
|
||||
if answer:
|
||||
response.answer_citationless = remove_answer_citations(answer)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def remove_answer_citations(answer: str) -> str:
|
||||
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
|
||||
|
||||
@@ -139,36 +184,7 @@ def handle_simplified_chat_message(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
response = ChatBasicResponse()
|
||||
final_context_docs: list[LlmDoc] = []
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.message_id = packet.message_id
|
||||
elif isinstance(packet, FinalUsedContextDocsResponse):
|
||||
final_context_docs = packet.final_context_docs
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.cited_documents = {
|
||||
citation.citation_num: citation.document_id
|
||||
for citation in packet.citations
|
||||
}
|
||||
|
||||
response.final_context_doc_indices = _get_final_context_doc_indices(
|
||||
final_context_docs, response.simple_search_docs
|
||||
)
|
||||
|
||||
response.answer = answer
|
||||
if answer:
|
||||
response.answer_citationless = remove_answer_citations(answer)
|
||||
|
||||
return response
|
||||
return _convert_packet_stream_to_response(packets)
|
||||
|
||||
|
||||
@router.post("/send-message-simple-with-history")
|
||||
@@ -287,35 +303,4 @@ def handle_send_message_simple_with_history(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
response = ChatBasicResponse()
|
||||
final_context_docs: list[LlmDoc] = []
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, FinalUsedContextDocsResponse):
|
||||
final_context_docs = packet.final_context_docs
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.cited_documents = {
|
||||
citation.citation_num: citation.document_id
|
||||
for citation in packet.citations
|
||||
}
|
||||
|
||||
response.final_context_doc_indices = _get_final_context_doc_indices(
|
||||
final_context_docs, response.simple_search_docs
|
||||
)
|
||||
|
||||
response.answer = answer
|
||||
if answer:
|
||||
response.answer_citationless = remove_answer_citations(answer)
|
||||
|
||||
return response
|
||||
return _convert_packet_stream_to_response(packets)
|
||||
|
||||
@@ -8,7 +8,8 @@ from danswer.search.enums import SearchType
|
||||
from danswer.search.models import ChunkContext
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from ee.danswer.server.manage.models import StandardAnswer
|
||||
|
||||
|
||||
class StandardAnswerRequest(BaseModel):
|
||||
@@ -73,10 +74,17 @@ class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
answer_citationless: str | None = None
|
||||
simple_search_docs: list[SimpleDoc] | None = None
|
||||
|
||||
top_documents: list[SavedSearchDoc] | None = None
|
||||
|
||||
error_msg: str | None = None
|
||||
message_id: int | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
final_context_doc_indices: list[int] | None = None
|
||||
# this is a map of the citation number to the document id
|
||||
cited_documents: dict[int, str] | None = None
|
||||
|
||||
# FOR BACKWARDS COMPATIBILITY
|
||||
# TODO: deprecate both of these
|
||||
simple_search_docs: list[SimpleDoc] | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
|
||||
@@ -32,6 +32,7 @@ from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import (
|
||||
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
|
||||
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
|
||||
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
|
||||
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -133,12 +134,23 @@ def get_answer_with_quote(
|
||||
query = query_request.messages[0].message
|
||||
logger.notice(f"Received query for one shot answer API with quotes: {query}")
|
||||
|
||||
persona = get_persona_by_id(
|
||||
persona_id=query_request.persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
if query_request.persona_config is not None:
|
||||
new_persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=query_request.persona_config,
|
||||
user=user,
|
||||
)
|
||||
persona = new_persona
|
||||
|
||||
elif query_request.persona_id is not None:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=query_request.persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
else:
|
||||
raise KeyError("Must provide persona ID or Persona Config")
|
||||
|
||||
llm = get_main_llm_from_tuple(
|
||||
get_default_llms() if not persona else get_llms_for_persona(persona)
|
||||
|
||||
83
backend/ee/danswer/server/query_and_chat/utils.py
Normal file
83
backend/ee/danswer/server/query_and_chat/utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.db.llm import fetch_existing_doc_sets
|
||||
from danswer.db.llm import fetch_existing_tools
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompts_by_ids
|
||||
from danswer.one_shot_answer.models import PersonaConfig
|
||||
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaConfig, db_session: Session, user: User | None = None
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
persona.prompts = [
|
||||
Prompt(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
system_prompt=p.system_prompt,
|
||||
task_prompt=p.task_prompt,
|
||||
include_citations=p.include_citations,
|
||||
datetime_aware=p.datetime_aware,
|
||||
)
|
||||
for p in persona_config.prompts
|
||||
]
|
||||
elif persona_config.prompt_ids:
|
||||
persona.prompts = get_prompts_by_ids(
|
||||
db_session=db_session, prompt_ids=persona_config.prompt_ids
|
||||
)
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(schema),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
||||
@@ -87,7 +87,7 @@ class ChatSessionMinimal(BaseModel):
|
||||
name: str | None
|
||||
first_user_message: str
|
||||
first_ai_message: str
|
||||
persona_name: str
|
||||
persona_name: str | None
|
||||
time_created: datetime
|
||||
feedback_type: QAFeedbackType | Literal["mixed"] | None
|
||||
|
||||
@@ -97,7 +97,7 @@ class ChatSessionSnapshot(BaseModel):
|
||||
user_email: str
|
||||
name: str | None
|
||||
messages: list[MessageSnapshot]
|
||||
persona_name: str
|
||||
persona_name: str | None
|
||||
time_created: datetime
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
|
||||
retrieved_documents: list[AbridgedSearchDoc]
|
||||
feedback_type: QAFeedbackType | None
|
||||
feedback_text: str | None
|
||||
persona_name: str
|
||||
persona_name: str | None
|
||||
user_email: str
|
||||
time_created: datetime
|
||||
|
||||
@@ -145,7 +145,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
|
||||
for ind, (user_message, ai_message) in enumerate(message_pairs)
|
||||
]
|
||||
|
||||
def to_json(self) -> dict[str, str]:
|
||||
def to_json(self) -> dict[str, str | None]:
|
||||
return {
|
||||
"chat_session_id": str(self.chat_session_id),
|
||||
"message_pair_num": str(self.message_pair_num),
|
||||
@@ -235,7 +235,9 @@ def fetch_and_process_chat_session_history_minimal(
|
||||
name=chat_session.description,
|
||||
first_user_message=first_user_message,
|
||||
first_ai_message=first_ai_message,
|
||||
persona_name=chat_session.persona.name,
|
||||
persona_name=chat_session.persona.name
|
||||
if chat_session.persona
|
||||
else None,
|
||||
time_created=chat_session.time_created,
|
||||
feedback_type=feedback_type,
|
||||
)
|
||||
@@ -300,7 +302,7 @@ def snapshot_from_chat_session(
|
||||
for message in messages
|
||||
if message.message_type != MessageType.SYSTEM
|
||||
],
|
||||
persona_name=chat_session.persona.name,
|
||||
persona_name=chat_session.persona.name if chat_session.persona else None,
|
||||
time_created=chat_session.time_created,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,6 +13,9 @@ from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.server.settings.models import Settings
|
||||
from danswer.server.settings.store import store_settings as store_base_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.standard_answer import (
|
||||
create_initial_default_standard_answer_category,
|
||||
)
|
||||
from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload
|
||||
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
|
||||
from ee.danswer.server.enterprise_settings.store import store_analytics_script
|
||||
@@ -21,6 +24,7 @@ from ee.danswer.server.enterprise_settings.store import (
|
||||
)
|
||||
from ee.danswer.server.enterprise_settings.store import upload_logo
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
|
||||
@@ -146,3 +150,6 @@ def seed_db() -> None:
|
||||
_seed_logo(db_session, seed_config.seeded_logo_path)
|
||||
_seed_enterprise_settings(seed_config)
|
||||
_seed_analytics_script(seed_config)
|
||||
|
||||
logger.notice("Verifying default standard answer category exists.")
|
||||
create_initial_default_standard_answer_category(db_session)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# Used for logging
|
||||
SLACK_CHANNEL_ID = "channel_id"
|
||||
@@ -73,3 +74,18 @@ PRESERVED_SEARCH_FIELDS = [
|
||||
"passage_prefix",
|
||||
"query_prefix",
|
||||
]
|
||||
|
||||
|
||||
# CORS
|
||||
def validate_cors_origin(origin: str) -> None:
|
||||
parsed = urlparse(origin)
|
||||
if parsed.scheme not in ["http", "https"] or not parsed.netloc:
|
||||
raise ValueError(f"Invalid CORS origin: '{origin}'")
|
||||
|
||||
|
||||
CORS_ALLOWED_ORIGIN = os.environ.get("CORS_ALLOWED_ORIGIN", "*").split(",") or ["*"]
|
||||
|
||||
# Validate non-wildcard origins
|
||||
for origin in CORS_ALLOWED_ORIGIN:
|
||||
if origin != "*" and (stripped_origin := origin.strip()):
|
||||
validate_cors_origin(stripped_origin)
|
||||
|
||||
@@ -12,6 +12,7 @@ command=python danswer/background/update.py
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
|
||||
|
||||
# Background jobs that must be run async due to long time to completion
|
||||
# NOTE: due to an issue with Celery + SQLAlchemy
|
||||
# (https://github.com/celery/celery/issues/7007#issuecomment-1740139367)
|
||||
@@ -37,11 +38,9 @@ autorestart=true
|
||||
# Job scheduler for periodic tasks
|
||||
[program:celery_beat]
|
||||
command=celery -A danswer.background.celery.celery_run:celery_app beat
|
||||
--loglevel=INFO
|
||||
--logfile=/var/log/celery_beat_supervisor.log
|
||||
environment=LOG_FILE_NAME=celery_beat
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
|
||||
# Listens for Slack messages and responds with answers
|
||||
# for all channels that the DanswerBot has been added to.
|
||||
@@ -68,4 +67,4 @@ command=tail -qF
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
autorestart=true
|
||||
@@ -51,6 +51,7 @@ def test_send_message_simple_with_history(reset: None) -> None:
|
||||
|
||||
# Check that the top document is the correct document
|
||||
assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
|
||||
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
|
||||
|
||||
# assert that the metadata is correct
|
||||
for doc in cc_pair_1.documents:
|
||||
|
||||
@@ -34,6 +34,7 @@ services:
|
||||
- OAUTH_CLIENT_SECRET=${OAUTH_CLIENT_SECRET:-}
|
||||
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
|
||||
- TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-}
|
||||
- CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-}
|
||||
# Gen AI Settings
|
||||
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
|
||||
- QA_TIMEOUT=${QA_TIMEOUT:-}
|
||||
|
||||
@@ -31,6 +31,7 @@ services:
|
||||
- SMTP_PASS=${SMTP_PASS:-}
|
||||
- EMAIL_FROM=${EMAIL_FROM:-}
|
||||
- TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-}
|
||||
- CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-}
|
||||
# Gen AI Settings
|
||||
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
|
||||
- QA_TIMEOUT=${QA_TIMEOUT:-}
|
||||
|
||||
@@ -27,4 +27,4 @@ spec:
|
||||
key: redis_password
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
name: env-configmap
|
||||
@@ -13,6 +13,7 @@ data:
|
||||
SMTP_USER: "" # 'your-email@company.com'
|
||||
SMTP_PASS: "" # 'your-gmail-password'
|
||||
EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead
|
||||
CORS_ALLOWED_ORIGIN: ""
|
||||
# Gen AI Settings
|
||||
GEN_AI_MAX_TOKENS: ""
|
||||
QA_TIMEOUT: "60"
|
||||
|
||||
@@ -25,10 +25,8 @@ import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { getDisplayNameForModel } from "@/lib/hooks";
|
||||
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
|
||||
import { Option } from "@/components/Dropdown";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
|
||||
import { useUserGroups } from "@/lib/hooks";
|
||||
import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
|
||||
import { checkLLMSupportsImageOutput, destructureValue } from "@/lib/llm/utils";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { checkUserIsNoAuthUser } from "@/lib/user";
|
||||
|
||||
@@ -47,7 +45,12 @@ import { FullLLMProvider } from "../configuration/llm/interfaces";
|
||||
import CollapsibleSection from "./CollapsibleSection";
|
||||
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
|
||||
import { Persona, StarterMessage } from "./interfaces";
|
||||
import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
|
||||
import {
|
||||
buildFinalPrompt,
|
||||
createPersona,
|
||||
providersContainImageGeneratingSupport,
|
||||
updatePersona,
|
||||
} from "./lib";
|
||||
import { Popover } from "@/components/popover/Popover";
|
||||
import {
|
||||
CameraIcon,
|
||||
@@ -167,7 +170,7 @@ export function AssistantEditor({
|
||||
const defaultProvider = llmProviders.find(
|
||||
(llmProvider) => llmProvider.is_default_provider
|
||||
);
|
||||
|
||||
const defaultProviderName = defaultProvider?.provider;
|
||||
const defaultModelName = defaultProvider?.default_model_name;
|
||||
const providerDisplayNameToProviderName = new Map<string, string>();
|
||||
llmProviders.forEach((llmProvider) => {
|
||||
@@ -187,10 +190,9 @@ export function AssistantEditor({
|
||||
});
|
||||
modelOptionsByProvider.set(llmProvider.name, providerOptions);
|
||||
});
|
||||
const providerSupportingImageGenerationExists = llmProviders.some(
|
||||
(provider) =>
|
||||
provider.provider === "openai" || provider.provider === "anthropic"
|
||||
);
|
||||
|
||||
const providerSupportingImageGenerationExists =
|
||||
providersContainImageGeneratingSupport(llmProviders);
|
||||
|
||||
const personaCurrentToolIds =
|
||||
existingPersona?.tools.map((tool) => tool.id) || [];
|
||||
@@ -342,7 +344,12 @@ export function AssistantEditor({
|
||||
|
||||
if (imageGenerationToolEnabled) {
|
||||
if (
|
||||
!checkLLMSupportsImageInput(
|
||||
!checkLLMSupportsImageOutput(
|
||||
providerDisplayNameToProviderName.get(
|
||||
values.llm_model_provider_override || ""
|
||||
) ||
|
||||
defaultProviderName ||
|
||||
"",
|
||||
values.llm_model_version_override || defaultModelName || ""
|
||||
)
|
||||
) {
|
||||
@@ -453,6 +460,15 @@ export function AssistantEditor({
|
||||
: false;
|
||||
}
|
||||
|
||||
const currentLLMSupportsImageOutput = checkLLMSupportsImageOutput(
|
||||
providerDisplayNameToProviderName.get(
|
||||
values.llm_model_provider_override || ""
|
||||
) ||
|
||||
defaultProviderName ||
|
||||
"",
|
||||
values.llm_model_version_override || defaultModelName || ""
|
||||
);
|
||||
|
||||
return (
|
||||
<Form className="w-full text-text-950">
|
||||
<div className="w-full flex gap-x-2 justify-center">
|
||||
@@ -757,9 +773,7 @@ export function AssistantEditor({
|
||||
<TooltipTrigger asChild>
|
||||
<div
|
||||
className={`w-fit ${
|
||||
!checkLLMSupportsImageInput(
|
||||
values.llm_model_version_override || ""
|
||||
)
|
||||
!currentLLMSupportsImageOutput
|
||||
? "opacity-70 cursor-not-allowed"
|
||||
: ""
|
||||
}`}
|
||||
@@ -771,17 +785,11 @@ export function AssistantEditor({
|
||||
onChange={() => {
|
||||
toggleToolInValues(imageGenerationTool.id);
|
||||
}}
|
||||
disabled={
|
||||
!checkLLMSupportsImageInput(
|
||||
values.llm_model_version_override || ""
|
||||
)
|
||||
}
|
||||
disabled={!currentLLMSupportsImageOutput}
|
||||
/>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{!checkLLMSupportsImageInput(
|
||||
values.llm_model_version_override || ""
|
||||
) && (
|
||||
{!currentLLMSupportsImageOutput && (
|
||||
<TooltipContent side="top" align="center">
|
||||
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
|
||||
To use Image Generation, select GPT-4o or another
|
||||
@@ -1051,15 +1059,15 @@ export function AssistantEditor({
|
||||
<Field
|
||||
name={`starter_messages[${index}].name`}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
@@ -1081,15 +1089,15 @@ export function AssistantEditor({
|
||||
<Field
|
||||
name={`starter_messages.${index}.description`}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
@@ -1112,15 +1120,15 @@ export function AssistantEditor({
|
||||
<Field
|
||||
name={`starter_messages[${index}].message`}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
as="textarea"
|
||||
autoComplete="off"
|
||||
/>
|
||||
|
||||
@@ -8,7 +8,11 @@ import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { useState, useMemo, useEffect } from "react";
|
||||
import { UniqueIdentifier } from "@dnd-kit/core";
|
||||
import { DraggableTable } from "@/components/table/DraggableTable";
|
||||
import { deletePersona, personaComparator } from "./lib";
|
||||
import {
|
||||
deletePersona,
|
||||
personaComparator,
|
||||
togglePersonaVisibility,
|
||||
} from "./lib";
|
||||
import { FiEdit2 } from "react-icons/fi";
|
||||
import { TrashIcon } from "@/components/icons/icons";
|
||||
import { getCurrentUser } from "@/lib/user";
|
||||
@@ -31,22 +35,6 @@ function PersonaTypeDisplay({ persona }: { persona: Persona }) {
|
||||
return <Text>Personal {persona.owner && <>({persona.owner.email})</>}</Text>;
|
||||
}
|
||||
|
||||
const togglePersonaVisibility = async (
|
||||
personaId: number,
|
||||
isVisible: boolean
|
||||
) => {
|
||||
const response = await fetch(`/api/admin/persona/${personaId}/visible`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
is_visible: !isVisible,
|
||||
}),
|
||||
});
|
||||
return response;
|
||||
};
|
||||
|
||||
export function PersonasTable({
|
||||
allPersonas,
|
||||
editablePersonas,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
||||
import { Persona, Prompt, StarterMessage } from "./interfaces";
|
||||
|
||||
interface PersonaCreationRequest {
|
||||
@@ -318,3 +319,50 @@ export function personaComparator(a: Persona, b: Persona) {
|
||||
|
||||
return closerToZeroNegativesFirstComparator(a.id, b.id);
|
||||
}
|
||||
|
||||
export const togglePersonaVisibility = async (
|
||||
personaId: number,
|
||||
isVisible: boolean
|
||||
) => {
|
||||
const response = await fetch(`/api/persona/${personaId}/visible`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
is_visible: !isVisible,
|
||||
}),
|
||||
});
|
||||
return response;
|
||||
};
|
||||
|
||||
export const togglePersonaPublicStatus = async (
|
||||
personaId: number,
|
||||
isPublic: boolean
|
||||
) => {
|
||||
const response = await fetch(`/api/persona/${personaId}/public`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
is_public: isPublic,
|
||||
}),
|
||||
});
|
||||
return response;
|
||||
};
|
||||
|
||||
export function checkPersonaRequiresImageGeneration(persona: Persona) {
|
||||
for (const tool of persona.tools) {
|
||||
if (tool.name === "ImageGenerationTool") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
export function providersContainImageGeneratingSupport(
|
||||
providers: FullLLMProvider[]
|
||||
) {
|
||||
return providers.some((provider) => provider.provider === "openai");
|
||||
}
|
||||
|
||||
@@ -7,10 +7,8 @@ import { Button, Card, Text, Title } from "@tremor/react";
|
||||
import useSWR from "swr";
|
||||
import { ModelPreview } from "../../../../components/embedding/ModelSelector";
|
||||
import {
|
||||
AVAILABLE_CLOUD_PROVIDERS,
|
||||
HostedEmbeddingModel,
|
||||
CloudEmbeddingModel,
|
||||
AVAILABLE_MODELS,
|
||||
} from "@/components/embedding/interfaces";
|
||||
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
@@ -24,10 +22,7 @@ export interface EmbeddingDetails {
|
||||
import { EmbeddingIcon } from "@/components/icons/icons";
|
||||
|
||||
import Link from "next/link";
|
||||
import {
|
||||
getCurrentModelCopy,
|
||||
SavedSearchSettings,
|
||||
} from "../../embeddings/interfaces";
|
||||
import { SavedSearchSettings } from "../../embeddings/interfaces";
|
||||
import UpgradingPage from "./UpgradingPage";
|
||||
import { useContext } from "react";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
|
||||
@@ -27,6 +27,8 @@ import {
|
||||
connectorConfigs,
|
||||
createConnectorInitialValues,
|
||||
createConnectorValidationSchema,
|
||||
defaultPruneFreqDays,
|
||||
defaultRefreshFreqMinutes,
|
||||
} from "@/lib/connectors/connectors";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import GDriveMain from "./pages/gdrive/GoogleDrivePage";
|
||||
@@ -154,7 +156,6 @@ export default function AddConnector({
|
||||
initialValues={createConnectorInitialValues(connector)}
|
||||
validationSchema={createConnectorValidationSchema(connector)}
|
||||
onSubmit={async (values) => {
|
||||
console.log(" Iam submiing the connector");
|
||||
const {
|
||||
name,
|
||||
groups,
|
||||
@@ -189,9 +190,9 @@ export default function AddConnector({
|
||||
|
||||
// Apply advanced configuration-specific transforms.
|
||||
const advancedConfiguration: any = {
|
||||
pruneFreq: pruneFreq * 60 * 60 * 24,
|
||||
pruneFreq: (pruneFreq || defaultPruneFreqDays) * 60 * 60 * 24,
|
||||
indexingStart: convertStringToDateTime(indexingStart),
|
||||
refreshFreq: refreshFreq * 60,
|
||||
refreshFreq: (refreshFreq || defaultRefreshFreqMinutes) * 60,
|
||||
};
|
||||
|
||||
// Google sites-specific handling
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
import React, { Dispatch, forwardRef, SetStateAction, useState } from "react";
|
||||
import React, {
|
||||
Dispatch,
|
||||
forwardRef,
|
||||
SetStateAction,
|
||||
useContext,
|
||||
useState,
|
||||
} from "react";
|
||||
import { Formik, Form, FormikProps } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
RerankerProvider,
|
||||
RerankingDetails,
|
||||
RerankingModel,
|
||||
rerankingModels,
|
||||
} from "./interfaces";
|
||||
import { FiExternalLink } from "react-icons/fi";
|
||||
@@ -15,6 +22,7 @@ import {
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button } from "@tremor/react";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
|
||||
interface RerankingDetailsFormProps {
|
||||
setRerankingDetails: Dispatch<SetStateAction<RerankingDetails>>;
|
||||
@@ -38,10 +46,15 @@ const RerankingDetailsForm = forwardRef<
|
||||
},
|
||||
ref
|
||||
) => {
|
||||
const [showGpuWarningModalModel, setShowGpuWarningModalModel] =
|
||||
useState<RerankingModel | null>(null);
|
||||
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false);
|
||||
const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] =
|
||||
useState(false);
|
||||
|
||||
const combinedSettings = useContext(SettingsContext);
|
||||
const gpuEnabled = combinedSettings?.settings.gpu_enabled;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
innerRef={ref}
|
||||
@@ -169,6 +182,11 @@ const RerankingDetailsForm = forwardRef<
|
||||
RerankerProvider.LITELLM
|
||||
) {
|
||||
setShowLiteLLMConfigurationModal(true);
|
||||
} else if (
|
||||
!card.rerank_provider_type &&
|
||||
!gpuEnabled
|
||||
) {
|
||||
setShowGpuWarningModalModel(card);
|
||||
}
|
||||
|
||||
if (!isSelected) {
|
||||
@@ -225,6 +243,33 @@ const RerankingDetailsForm = forwardRef<
|
||||
})}
|
||||
</div>
|
||||
|
||||
{showGpuWarningModalModel && (
|
||||
<Modal
|
||||
onOutsideClick={() => setShowGpuWarningModalModel(null)}
|
||||
width="w-[500px] flex flex-col"
|
||||
title="GPU Not Enabled"
|
||||
>
|
||||
<>
|
||||
<p className="text-error font-semibold">Warning:</p>
|
||||
<p>
|
||||
Local reranking models require significant computational
|
||||
resources and may perform slowly without GPU
|
||||
acceleration. Consider switching to GPU-enabled
|
||||
infrastructure or using a cloud-based alternative for
|
||||
better performance.
|
||||
</p>
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
onClick={() => setShowGpuWarningModalModel(null)}
|
||||
color="blue"
|
||||
size="xs"
|
||||
>
|
||||
Understood
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
</Modal>
|
||||
)}
|
||||
{showLiteLLMConfigurationModal && (
|
||||
<Modal
|
||||
onOutsideClick={() => {
|
||||
|
||||
@@ -5,6 +5,7 @@ export interface Settings {
|
||||
maximum_chat_retention_days: number | null;
|
||||
notifications: Notification[];
|
||||
needs_reindexing: boolean;
|
||||
gpu_enabled: boolean;
|
||||
}
|
||||
|
||||
export interface Notification {
|
||||
|
||||
@@ -1,14 +1,7 @@
|
||||
import { User } from "@/lib/types";
|
||||
import { Persona } from "../admin/assistants/interfaces";
|
||||
import { checkUserOwnsAssistant } from "@/lib/assistants/checkOwnership";
|
||||
import {
|
||||
FiImage,
|
||||
FiLock,
|
||||
FiMoreHorizontal,
|
||||
FiSearch,
|
||||
FiUnlock,
|
||||
} from "react-icons/fi";
|
||||
import { CustomTooltip } from "@/components/tooltip/CustomTooltip";
|
||||
import { FiLock, FiUnlock } from "react-icons/fi";
|
||||
|
||||
export function AssistantSharedStatusDisplay({
|
||||
assistant,
|
||||
@@ -56,20 +49,6 @@ export function AssistantSharedStatusDisplay({
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<div className="relative mt-4 text-xs flex text-subtle">
|
||||
<span className="font-medium">Powers:</span>{" "}
|
||||
{assistant.tools.length == 0 ? (
|
||||
<p className="ml-2">None</p>
|
||||
) : (
|
||||
assistant.tools.map((tool, ind) => {
|
||||
if (tool.name === "SearchTool") {
|
||||
return <FiSearch key={ind} className="ml-1 h-3 w-3 my-auto" />;
|
||||
} else if (tool.name === "ImageGenerationTool") {
|
||||
return <FiImage key={ind} className="ml-1 h-3 w-3 my-auto" />;
|
||||
}
|
||||
})
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -6,11 +6,15 @@ import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { Divider, Text } from "@tremor/react";
|
||||
import {
|
||||
FiEdit2,
|
||||
FiFigma,
|
||||
FiMenu,
|
||||
FiMinus,
|
||||
FiMoreHorizontal,
|
||||
FiPlus,
|
||||
FiSearch,
|
||||
FiShare,
|
||||
FiShare2,
|
||||
FiToggleLeft,
|
||||
FiTrash,
|
||||
FiX,
|
||||
} from "react-icons/fi";
|
||||
@@ -50,11 +54,14 @@ import {
|
||||
verticalListSortingStrategy,
|
||||
} from "@dnd-kit/sortable";
|
||||
import { useSortable } from "@dnd-kit/sortable";
|
||||
import { CSS } from "@dnd-kit/utilities";
|
||||
|
||||
import { DragHandle } from "@/components/table/DragHandle";
|
||||
import { deletePersona } from "@/app/admin/assistants/lib";
|
||||
import {
|
||||
deletePersona,
|
||||
togglePersonaPublicStatus,
|
||||
} from "@/app/admin/assistants/lib";
|
||||
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
|
||||
import { MakePublicAssistantModal } from "@/app/chat/modal/MakePublicAssistantModal";
|
||||
|
||||
function DraggableAssistantListItem(props: any) {
|
||||
const {
|
||||
@@ -81,7 +88,7 @@ function DraggableAssistantListItem(props: any) {
|
||||
<DragHandle />
|
||||
</div>
|
||||
<div className="flex-grow">
|
||||
<AssistantListItem del {...props} />
|
||||
<AssistantListItem {...props} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
@@ -95,6 +102,7 @@ function AssistantListItem({
|
||||
isVisible,
|
||||
setPopup,
|
||||
deleteAssistant,
|
||||
shareAssistant,
|
||||
}: {
|
||||
assistant: Persona;
|
||||
user: User | null;
|
||||
@@ -102,7 +110,7 @@ function AssistantListItem({
|
||||
allAssistantIds: string[];
|
||||
isVisible: boolean;
|
||||
deleteAssistant: Dispatch<SetStateAction<Persona | null>>;
|
||||
|
||||
shareAssistant: Dispatch<SetStateAction<Persona | null>>;
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
@@ -258,6 +266,18 @@ function AssistantListItem({
|
||||
) : (
|
||||
<></>
|
||||
),
|
||||
isOwnedByUser ? (
|
||||
<div
|
||||
key="delete"
|
||||
className="flex items-center gap-x-2"
|
||||
onClick={() => shareAssistant(assistant)}
|
||||
>
|
||||
{assistant.is_public ? <FiMinus /> : <FiPlus />} Make{" "}
|
||||
{assistant.is_public ? "Private" : "Public"}
|
||||
</div>
|
||||
) : (
|
||||
<></>
|
||||
),
|
||||
]}
|
||||
</DefaultPopover>
|
||||
</div>
|
||||
@@ -286,10 +306,15 @@ export function AssistantsList({
|
||||
user?.preferences?.chosen_assistants &&
|
||||
!user?.preferences?.chosen_assistants?.includes(assistant.id)
|
||||
);
|
||||
|
||||
const allAssistantIds = assistants.map((assistant) =>
|
||||
assistant.id.toString()
|
||||
);
|
||||
|
||||
const [deletingPersona, setDeletingPersona] = useState<Persona | null>(null);
|
||||
const [makePublicPersona, setMakePublicPersona] = useState<Persona | null>(
|
||||
null
|
||||
);
|
||||
|
||||
const { popup, setPopup } = usePopup();
|
||||
const router = useRouter();
|
||||
@@ -307,7 +332,7 @@ export function AssistantsList({
|
||||
|
||||
async function handleDragEnd(event: DragEndEvent) {
|
||||
const { active, over } = event;
|
||||
|
||||
filteredAssistants;
|
||||
if (over && active.id !== over.id) {
|
||||
setFilteredAssistants((assistants) => {
|
||||
const oldIndex = assistants.findIndex(
|
||||
@@ -351,6 +376,20 @@ export function AssistantsList({
|
||||
/>
|
||||
)}
|
||||
|
||||
{makePublicPersona && (
|
||||
<MakePublicAssistantModal
|
||||
isPublic={makePublicPersona.is_public}
|
||||
onClose={() => setMakePublicPersona(null)}
|
||||
onShare={async (newPublicStatus: boolean) => {
|
||||
await togglePersonaPublicStatus(
|
||||
makePublicPersona.id,
|
||||
newPublicStatus
|
||||
);
|
||||
router.refresh();
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="mx-auto mobile:w-[90%] desktop:w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar">
|
||||
<AssistantsPageTitle>My Assistants</AssistantsPageTitle>
|
||||
|
||||
@@ -403,6 +442,7 @@ export function AssistantsList({
|
||||
{filteredAssistants.map((assistant, index) => (
|
||||
<DraggableAssistantListItem
|
||||
deleteAssistant={setDeletingPersona}
|
||||
shareAssistant={setMakePublicPersona}
|
||||
key={assistant.id}
|
||||
assistant={assistant}
|
||||
user={user}
|
||||
@@ -431,6 +471,7 @@ export function AssistantsList({
|
||||
{ownedButHiddenAssistants.map((assistant, index) => (
|
||||
<AssistantListItem
|
||||
deleteAssistant={setDeletingPersona}
|
||||
shareAssistant={setMakePublicPersona}
|
||||
key={assistant.id}
|
||||
assistant={assistant}
|
||||
user={user}
|
||||
|
||||
@@ -426,7 +426,7 @@ export function ChatPage({
|
||||
}, [existingChatSessionId]);
|
||||
|
||||
const [message, setMessage] = useState(
|
||||
searchParams.get(SEARCH_PARAM_NAMES.USER_MESSAGE) || ""
|
||||
searchParams.get(SEARCH_PARAM_NAMES.USER_PROMPT) || ""
|
||||
);
|
||||
|
||||
const [completeMessageDetail, setCompleteMessageDetail] = useState<
|
||||
|
||||
@@ -549,6 +549,7 @@ export function ChatInputBar({
|
||||
tab
|
||||
content={(close, ref) => (
|
||||
<LlmTab
|
||||
currentAssistant={alternativeAssistant || selectedAssistant}
|
||||
openModelSettings={openModelSettings}
|
||||
currentLlm={
|
||||
llmOverrideManager.llmOverride.modelName ||
|
||||
|
||||
@@ -582,7 +582,7 @@ export function personaIncludesImage(selectedPersona: Persona) {
|
||||
|
||||
const PARAMS_TO_SKIP = [
|
||||
SEARCH_PARAM_NAMES.SUBMIT_ON_LOAD,
|
||||
SEARCH_PARAM_NAMES.USER_MESSAGE,
|
||||
SEARCH_PARAM_NAMES.USER_PROMPT,
|
||||
SEARCH_PARAM_NAMES.TITLE,
|
||||
// only use these if explicitly passed in
|
||||
SEARCH_PARAM_NAMES.CHAT_ID,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import React from "react";
|
||||
import { useState, ReactNode } from "react";
|
||||
import React, { useState, ReactNode, useCallback, useMemo, memo } from "react";
|
||||
import { FiCheck, FiCopy } from "react-icons/fi";
|
||||
|
||||
const CODE_BLOCK_PADDING_TYPE = { padding: "1rem" };
|
||||
@@ -11,21 +10,109 @@ interface CodeBlockProps {
|
||||
[key: string]: any;
|
||||
}
|
||||
|
||||
export function CodeBlock({
|
||||
export const CodeBlock = memo(function CodeBlock({
|
||||
className = "",
|
||||
children,
|
||||
content,
|
||||
...props
|
||||
}: CodeBlockProps) {
|
||||
const language = className
|
||||
.split(" ")
|
||||
.filter((cls) => cls.startsWith("language-"))
|
||||
.map((cls) => cls.replace("language-", ""))
|
||||
.join(" ");
|
||||
const [copied, setCopied] = useState(false);
|
||||
|
||||
const language = useMemo(() => {
|
||||
return className
|
||||
.split(" ")
|
||||
.filter((cls) => cls.startsWith("language-"))
|
||||
.map((cls) => cls.replace("language-", ""))
|
||||
.join(" ");
|
||||
}, [className]);
|
||||
|
||||
const codeText = useMemo(() => {
|
||||
let codeText: string | null = null;
|
||||
if (
|
||||
props.node?.position?.start?.offset &&
|
||||
props.node?.position?.end?.offset
|
||||
) {
|
||||
codeText = content.slice(
|
||||
props.node.position.start.offset,
|
||||
props.node.position.end.offset
|
||||
);
|
||||
codeText = codeText.trim();
|
||||
|
||||
// Find the last occurrence of closing backticks
|
||||
const lastBackticksIndex = codeText.lastIndexOf("```");
|
||||
if (lastBackticksIndex !== -1) {
|
||||
codeText = codeText.slice(0, lastBackticksIndex + 3);
|
||||
}
|
||||
|
||||
// Remove the language declaration and trailing backticks
|
||||
const codeLines = codeText.split("\n");
|
||||
if (
|
||||
codeLines.length > 1 &&
|
||||
(codeLines[0].startsWith("```") ||
|
||||
codeLines[0].trim().startsWith("```"))
|
||||
) {
|
||||
codeLines.shift(); // Remove the first line with the language declaration
|
||||
if (
|
||||
codeLines[codeLines.length - 1] === "```" ||
|
||||
codeLines[codeLines.length - 1]?.trim() === "```"
|
||||
) {
|
||||
codeLines.pop(); // Remove the last line with the trailing backticks
|
||||
}
|
||||
|
||||
const minIndent = codeLines
|
||||
.filter((line) => line.trim().length > 0)
|
||||
.reduce((min, line) => {
|
||||
const match = line.match(/^\s*/);
|
||||
return Math.min(min, match ? match[0].length : 0);
|
||||
}, Infinity);
|
||||
|
||||
const formattedCodeLines = codeLines.map((line) =>
|
||||
line.slice(minIndent)
|
||||
);
|
||||
codeText = formattedCodeLines.join("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// handle unknown languages. They won't have a `node.position.start.offset`
|
||||
if (!codeText) {
|
||||
const findTextNode = (node: any): string | null => {
|
||||
if (node.type === "text") {
|
||||
return node.value;
|
||||
}
|
||||
let finalResult = "";
|
||||
if (node.children) {
|
||||
for (const child of node.children) {
|
||||
const result = findTextNode(child);
|
||||
if (result) {
|
||||
finalResult += result;
|
||||
}
|
||||
}
|
||||
}
|
||||
return finalResult;
|
||||
};
|
||||
|
||||
codeText = findTextNode(props.node);
|
||||
}
|
||||
|
||||
return codeText;
|
||||
}, [content, props.node]);
|
||||
|
||||
const handleCopy = useCallback(
|
||||
(event: React.MouseEvent) => {
|
||||
event.preventDefault();
|
||||
if (!codeText) {
|
||||
return;
|
||||
}
|
||||
|
||||
navigator.clipboard.writeText(codeText).then(() => {
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000);
|
||||
});
|
||||
},
|
||||
[codeText]
|
||||
);
|
||||
|
||||
if (!language) {
|
||||
// this is the case of a single "`" e.g. `hi`
|
||||
if (typeof children === "string") {
|
||||
return <code className={className}>{children}</code>;
|
||||
}
|
||||
@@ -39,82 +126,6 @@ export function CodeBlock({
|
||||
);
|
||||
}
|
||||
|
||||
let codeText: string | null = null;
|
||||
if (
|
||||
props.node?.position?.start?.offset &&
|
||||
props.node?.position?.end?.offset
|
||||
) {
|
||||
codeText = content.slice(
|
||||
props.node.position.start.offset,
|
||||
props.node.position.end.offset
|
||||
);
|
||||
codeText = codeText.trim();
|
||||
|
||||
// Find the last occurrence of closing backticks
|
||||
const lastBackticksIndex = codeText.lastIndexOf("```");
|
||||
if (lastBackticksIndex !== -1) {
|
||||
codeText = codeText.slice(0, lastBackticksIndex + 3);
|
||||
}
|
||||
|
||||
// Remove the language declaration and trailing backticks
|
||||
const codeLines = codeText.split("\n");
|
||||
if (
|
||||
codeLines.length > 1 &&
|
||||
(codeLines[0].startsWith("```") || codeLines[0].trim().startsWith("```"))
|
||||
) {
|
||||
codeLines.shift(); // Remove the first line with the language declaration
|
||||
if (
|
||||
codeLines[codeLines.length - 1] === "```" ||
|
||||
codeLines[codeLines.length - 1]?.trim() === "```"
|
||||
) {
|
||||
codeLines.pop(); // Remove the last line with the trailing backticks
|
||||
}
|
||||
|
||||
const minIndent = codeLines
|
||||
.filter((line) => line.trim().length > 0)
|
||||
.reduce((min, line) => {
|
||||
const match = line.match(/^\s*/);
|
||||
return Math.min(min, match ? match[0].length : 0);
|
||||
}, Infinity);
|
||||
|
||||
const formattedCodeLines = codeLines.map((line) => line.slice(minIndent));
|
||||
codeText = formattedCodeLines.join("\n");
|
||||
}
|
||||
}
|
||||
|
||||
// handle unknown languages. They won't have a `node.position.start.offset`
|
||||
if (!codeText) {
|
||||
const findTextNode = (node: any): string | null => {
|
||||
if (node.type === "text") {
|
||||
return node.value;
|
||||
}
|
||||
let finalResult = "";
|
||||
if (node.children) {
|
||||
for (const child of node.children) {
|
||||
const result = findTextNode(child);
|
||||
if (result) {
|
||||
finalResult += result;
|
||||
}
|
||||
}
|
||||
}
|
||||
return finalResult;
|
||||
};
|
||||
|
||||
codeText = findTextNode(props.node);
|
||||
}
|
||||
|
||||
const handleCopy = (event: React.MouseEvent) => {
|
||||
event.preventDefault();
|
||||
if (!codeText) {
|
||||
return;
|
||||
}
|
||||
|
||||
navigator.clipboard.writeText(codeText).then(() => {
|
||||
setCopied(true);
|
||||
setTimeout(() => setCopied(false), 2000); // Reset copy status after 2 seconds
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="overflow-x-hidden">
|
||||
<div className="flex mx-3 py-2 text-xs">
|
||||
@@ -143,4 +154,4 @@ export function CodeBlock({
|
||||
</pre>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
33
web/src/app/chat/message/MemoizedTextComponents.tsx
Normal file
33
web/src/app/chat/message/MemoizedTextComponents.tsx
Normal file
@@ -0,0 +1,33 @@
|
||||
import { Citation } from "@/components/search/results/Citation";
|
||||
import React, { memo } from "react";
|
||||
|
||||
export const MemoizedLink = memo((props: any) => {
|
||||
const { node, ...rest } = props;
|
||||
const value = rest.children;
|
||||
|
||||
if (value?.toString().startsWith("*")) {
|
||||
return (
|
||||
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
|
||||
);
|
||||
} else if (value?.toString().startsWith("[")) {
|
||||
return <Citation link={rest?.href}>{rest.children}</Citation>;
|
||||
} else {
|
||||
return (
|
||||
<a
|
||||
onMouseDown={() =>
|
||||
rest.href ? window.open(rest.href, "_blank") : undefined
|
||||
}
|
||||
className="cursor-pointer text-link hover:text-link-hover"
|
||||
>
|
||||
{rest.children}
|
||||
</a>
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
export const MemoizedParagraph = memo(({ node, ...props }: any) => (
|
||||
<p {...props} className="text-default" />
|
||||
));
|
||||
|
||||
MemoizedLink.displayName = "MemoizedLink";
|
||||
MemoizedParagraph.displayName = "MemoizedParagraph";
|
||||
@@ -8,14 +8,7 @@ import {
|
||||
FiGlobe,
|
||||
} from "react-icons/fi";
|
||||
import { FeedbackType } from "../types";
|
||||
import {
|
||||
Dispatch,
|
||||
SetStateAction,
|
||||
useContext,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
import React, { useContext, useEffect, useMemo, useRef, useState } from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import {
|
||||
DanswerDocument,
|
||||
@@ -46,12 +39,7 @@ import { AssistantIcon } from "@/components/assistants/AssistantIcon";
|
||||
import { Citation } from "@/components/search/results/Citation";
|
||||
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
|
||||
|
||||
import {
|
||||
ThumbsUpIcon,
|
||||
ThumbsDownIcon,
|
||||
LikeFeedback,
|
||||
DislikeFeedback,
|
||||
} from "@/components/icons/icons";
|
||||
import { LikeFeedback, DislikeFeedback } from "@/components/icons/icons";
|
||||
import {
|
||||
CustomTooltip,
|
||||
TooltipGroup,
|
||||
@@ -65,6 +53,7 @@ import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
|
||||
import RegenerateOption from "../RegenerateOption";
|
||||
import { LlmOverride } from "@/lib/hooks";
|
||||
import { ContinueGenerating } from "./ContinueMessage";
|
||||
import { MemoizedLink, MemoizedParagraph } from "./MemoizedTextComponents";
|
||||
|
||||
const TOOLS_WITH_CUSTOM_HANDLING = [
|
||||
SEARCH_TOOL_NAME,
|
||||
@@ -367,41 +356,8 @@ export const AIMessage = ({
|
||||
key={messageId}
|
||||
className="prose max-w-full text-base"
|
||||
components={{
|
||||
a: (props) => {
|
||||
const { node, ...rest } = props;
|
||||
const value = rest.children;
|
||||
|
||||
if (value?.toString().startsWith("*")) {
|
||||
return (
|
||||
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
|
||||
);
|
||||
} else if (
|
||||
value?.toString().startsWith("[")
|
||||
) {
|
||||
// for some reason <a> tags cause the onClick to not apply
|
||||
// and the links are unclickable
|
||||
// TODO: fix the fact that you have to double click to follow link
|
||||
// for the first link
|
||||
return (
|
||||
<Citation link={rest?.href}>
|
||||
{rest.children}
|
||||
</Citation>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<a
|
||||
onMouseDown={() =>
|
||||
rest.href
|
||||
? window.open(rest.href, "_blank")
|
||||
: undefined
|
||||
}
|
||||
className="cursor-pointer text-link hover:text-link-hover"
|
||||
>
|
||||
{rest.children}
|
||||
</a>
|
||||
);
|
||||
}
|
||||
},
|
||||
a: MemoizedLink,
|
||||
p: MemoizedParagraph,
|
||||
code: (props) => (
|
||||
<CodeBlock
|
||||
className="w-full"
|
||||
@@ -409,9 +365,6 @@ export const AIMessage = ({
|
||||
content={content as string}
|
||||
/>
|
||||
),
|
||||
p: ({ node, ...props }) => (
|
||||
<p {...props} className="text-default" />
|
||||
),
|
||||
}}
|
||||
remarkPlugins={[remarkGfm]}
|
||||
rehypePlugins={[
|
||||
|
||||
72
web/src/app/chat/modal/MakePublicAssistantModal.tsx
Normal file
72
web/src/app/chat/modal/MakePublicAssistantModal.tsx
Normal file
@@ -0,0 +1,72 @@
|
||||
import { ModalWrapper } from "@/components/modals/ModalWrapper";
|
||||
import { Button, Divider, Text } from "@tremor/react";
|
||||
|
||||
export function MakePublicAssistantModal({
|
||||
isPublic,
|
||||
onShare,
|
||||
onClose,
|
||||
}: {
|
||||
isPublic: boolean;
|
||||
onShare: (shared: boolean) => void;
|
||||
onClose: () => void;
|
||||
}) {
|
||||
return (
|
||||
<ModalWrapper onClose={onClose} modalClassName="max-w-3xl">
|
||||
<div className="space-y-6">
|
||||
<h2 className="text-2xl font-bold text-emphasis">
|
||||
{isPublic ? "Public Assistant" : "Make Assistant Public"}
|
||||
</h2>
|
||||
|
||||
<Text>
|
||||
This assistant is currently{" "}
|
||||
<span className="font-semibold">
|
||||
{isPublic ? "public" : "private"}
|
||||
</span>
|
||||
.
|
||||
{isPublic
|
||||
? " Anyone can currently access this assistant."
|
||||
: " Only you can access this assistant."}
|
||||
</Text>
|
||||
|
||||
<Divider />
|
||||
|
||||
{isPublic ? (
|
||||
<div className="space-y-4">
|
||||
<Text>
|
||||
To restrict access to this assistant, you can make it private
|
||||
again.
|
||||
</Text>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
await onShare?.(false);
|
||||
onClose();
|
||||
}}
|
||||
size="sm"
|
||||
color="red"
|
||||
>
|
||||
Make Assistant Private
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-4">
|
||||
<Text>
|
||||
Making this assistant public will allow anyone with the link to
|
||||
view and use it. Ensure that all content and capabilities of the
|
||||
assistant are safe to share.
|
||||
</Text>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
await onShare?.(true);
|
||||
onClose();
|
||||
}}
|
||||
size="sm"
|
||||
color="green"
|
||||
>
|
||||
Make Assistant Public
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
@@ -4,10 +4,15 @@ import React, { forwardRef, useCallback, useState } from "react";
|
||||
import { debounce } from "lodash";
|
||||
import { Text } from "@tremor/react";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { destructureValue, structureValue } from "@/lib/llm/utils";
|
||||
import {
|
||||
checkLLMSupportsImageInput,
|
||||
destructureValue,
|
||||
structureValue,
|
||||
} from "@/lib/llm/utils";
|
||||
import { updateModelOverrideForChatSession } from "../../lib";
|
||||
import { GearIcon } from "@/components/icons/icons";
|
||||
import { LlmList } from "@/components/llm/LLMList";
|
||||
import { checkPersonaRequiresImageGeneration } from "@/app/admin/assistants/lib";
|
||||
|
||||
interface LlmTabProps {
|
||||
llmOverrideManager: LlmOverrideManager;
|
||||
@@ -15,13 +20,24 @@ interface LlmTabProps {
|
||||
openModelSettings: () => void;
|
||||
chatSessionId?: number;
|
||||
close: () => void;
|
||||
currentAssistant: Persona;
|
||||
}
|
||||
|
||||
export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
|
||||
(
|
||||
{ llmOverrideManager, chatSessionId, currentLlm, close, openModelSettings },
|
||||
{
|
||||
llmOverrideManager,
|
||||
chatSessionId,
|
||||
currentLlm,
|
||||
close,
|
||||
openModelSettings,
|
||||
currentAssistant,
|
||||
},
|
||||
ref
|
||||
) => {
|
||||
const requiresImageGeneration =
|
||||
checkPersonaRequiresImageGeneration(currentAssistant);
|
||||
|
||||
const { llmProviders } = useChatContext();
|
||||
const { setLlmOverride, temperature, setTemperature } = llmOverrideManager;
|
||||
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
|
||||
@@ -55,6 +71,7 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
|
||||
</button>
|
||||
</div>
|
||||
<LlmList
|
||||
requiresImageGeneration={requiresImageGeneration}
|
||||
llmProviders={llmProviders}
|
||||
currentLlm={currentLlm}
|
||||
onSelect={(value: string | null) => {
|
||||
|
||||
@@ -1,453 +0,0 @@
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { DocumentSet, Tag, ValidSources } from "@/lib/types";
|
||||
import { SourceMetadata } from "@/lib/search/interfaces";
|
||||
import {
|
||||
FiBook,
|
||||
FiBookmark,
|
||||
FiCalendar,
|
||||
FiFilter,
|
||||
FiMap,
|
||||
FiTag,
|
||||
FiX,
|
||||
} from "react-icons/fi";
|
||||
import { DateRangePickerValue } from "@tremor/react";
|
||||
import { listSourceMetadata } from "@/lib/sources";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { BasicClickable } from "@/components/BasicClickable";
|
||||
import { ControlledPopup, DefaultDropdownElement } from "@/components/Dropdown";
|
||||
import { getXDaysAgo } from "@/lib/dateUtils";
|
||||
import { SourceSelectorProps } from "@/components/search/filtering/Filters";
|
||||
import { containsObject, objectsAreEquivalent } from "@/lib/contains";
|
||||
|
||||
enum FilterType {
|
||||
Source = "Source",
|
||||
KnowledgeSet = "Knowledge Set",
|
||||
TimeRange = "Time Range",
|
||||
Tag = "Tag",
|
||||
}
|
||||
|
||||
function SelectedBubble({
|
||||
children,
|
||||
onClick,
|
||||
}: {
|
||||
children: string | JSX.Element;
|
||||
onClick: () => void;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
className={
|
||||
"flex text-xs cursor-pointer items-center border border-border " +
|
||||
"py-1 rounded-lg px-2 w-fit select-none hover:bg-hover"
|
||||
}
|
||||
onClick={onClick}
|
||||
>
|
||||
{children}
|
||||
<FiX className="ml-2" size={14} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function SelectFilterType({
|
||||
onSelect,
|
||||
hasSources,
|
||||
hasKnowledgeSets,
|
||||
hasTags,
|
||||
}: {
|
||||
onSelect: (filterType: FilterType) => void;
|
||||
hasSources: boolean;
|
||||
hasKnowledgeSets: boolean;
|
||||
hasTags: boolean;
|
||||
}) {
|
||||
return (
|
||||
<div className="w-64">
|
||||
{hasSources && (
|
||||
<DefaultDropdownElement
|
||||
key={FilterType.Source}
|
||||
name={FilterType.Source}
|
||||
icon={FiMap}
|
||||
onSelect={() => onSelect(FilterType.Source)}
|
||||
isSelected={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasKnowledgeSets && (
|
||||
<DefaultDropdownElement
|
||||
key={FilterType.KnowledgeSet}
|
||||
name={FilterType.KnowledgeSet}
|
||||
icon={FiBook}
|
||||
onSelect={() => onSelect(FilterType.KnowledgeSet)}
|
||||
isSelected={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{hasTags && (
|
||||
<DefaultDropdownElement
|
||||
key={FilterType.Tag}
|
||||
name={FilterType.Tag}
|
||||
icon={FiTag}
|
||||
onSelect={() => onSelect(FilterType.Tag)}
|
||||
isSelected={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
<DefaultDropdownElement
|
||||
key={FilterType.TimeRange}
|
||||
name={FilterType.TimeRange}
|
||||
icon={FiCalendar}
|
||||
onSelect={() => onSelect(FilterType.TimeRange)}
|
||||
isSelected={false}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function SourcesSection({
|
||||
sources,
|
||||
selectedSources,
|
||||
onSelect,
|
||||
}: {
|
||||
sources: SourceMetadata[];
|
||||
selectedSources: string[];
|
||||
onSelect: (source: SourceMetadata) => void;
|
||||
}) {
|
||||
return (
|
||||
<div className="w-64">
|
||||
{sources.map((source) => (
|
||||
<DefaultDropdownElement
|
||||
key={source.internalName}
|
||||
name={source.displayName}
|
||||
icon={source.icon}
|
||||
onSelect={() => onSelect(source)}
|
||||
isSelected={selectedSources.includes(source.internalName)}
|
||||
includeCheckbox
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function KnowledgeSetsSection({
|
||||
documentSets,
|
||||
selectedDocumentSets,
|
||||
onSelect,
|
||||
}: {
|
||||
documentSets: DocumentSet[];
|
||||
selectedDocumentSets: string[];
|
||||
onSelect: (documentSetName: string) => void;
|
||||
}) {
|
||||
return (
|
||||
<div className="w-64">
|
||||
{documentSets.map((documentSet) => (
|
||||
<DefaultDropdownElement
|
||||
key={documentSet.name}
|
||||
name={documentSet.name}
|
||||
icon={FiBookmark}
|
||||
onSelect={() => onSelect(documentSet.name)}
|
||||
isSelected={selectedDocumentSets.includes(documentSet.name)}
|
||||
includeCheckbox
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const LAST_30_DAYS = "Last 30 days";
|
||||
const LAST_7_DAYS = "Last 7 days";
|
||||
const TODAY = "Today";
|
||||
|
||||
function TimeRangeSection({
|
||||
selectedTimeRange,
|
||||
onSelect,
|
||||
}: {
|
||||
selectedTimeRange: string | null;
|
||||
onSelect: (timeRange: DateRangePickerValue) => void;
|
||||
}) {
|
||||
return (
|
||||
<div className="w-64">
|
||||
<DefaultDropdownElement
|
||||
key={LAST_30_DAYS}
|
||||
name={LAST_30_DAYS}
|
||||
onSelect={() =>
|
||||
onSelect({
|
||||
to: new Date(),
|
||||
from: getXDaysAgo(30),
|
||||
selectValue: LAST_30_DAYS,
|
||||
})
|
||||
}
|
||||
isSelected={selectedTimeRange === LAST_30_DAYS}
|
||||
/>
|
||||
|
||||
<DefaultDropdownElement
|
||||
key={LAST_7_DAYS}
|
||||
name={LAST_7_DAYS}
|
||||
onSelect={() =>
|
||||
onSelect({
|
||||
to: new Date(),
|
||||
from: getXDaysAgo(7),
|
||||
selectValue: LAST_7_DAYS,
|
||||
})
|
||||
}
|
||||
isSelected={selectedTimeRange === LAST_7_DAYS}
|
||||
/>
|
||||
|
||||
<DefaultDropdownElement
|
||||
key={TODAY}
|
||||
name={TODAY}
|
||||
onSelect={() =>
|
||||
onSelect({
|
||||
to: new Date(),
|
||||
from: getXDaysAgo(1),
|
||||
selectValue: TODAY,
|
||||
})
|
||||
}
|
||||
isSelected={selectedTimeRange === TODAY}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function TagsSection({
|
||||
availableTags,
|
||||
selectedTags,
|
||||
onSelect,
|
||||
}: {
|
||||
availableTags: Tag[];
|
||||
selectedTags: Tag[];
|
||||
onSelect: (tag: Tag) => void;
|
||||
}) {
|
||||
const [filterValue, setFilterValue] = useState("");
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (inputRef.current) {
|
||||
inputRef.current.focus();
|
||||
}
|
||||
}, []);
|
||||
|
||||
const filterValueLower = filterValue.toLowerCase();
|
||||
const filteredTags = filterValueLower
|
||||
? availableTags.filter(
|
||||
(tags) =>
|
||||
tags.tag_value.toLowerCase().startsWith(filterValueLower) ||
|
||||
tags.tag_key.toLowerCase().startsWith(filterValueLower)
|
||||
)
|
||||
: availableTags;
|
||||
|
||||
return (
|
||||
<div className="w-96">
|
||||
<div className="max-h-48 overflow-y-auto">
|
||||
{filteredTags.length > 0 ? (
|
||||
filteredTags.map((tag) => (
|
||||
<DefaultDropdownElement
|
||||
key={tag.tag_key + tag.tag_value}
|
||||
name={
|
||||
<div className="max-w-full break-all line-clamp-1 text-ellipsis">
|
||||
{tag.tag_key}
|
||||
<b>=</b>
|
||||
{tag.tag_value}
|
||||
</div>
|
||||
}
|
||||
onSelect={() => onSelect(tag)}
|
||||
isSelected={selectedTags.includes(tag)}
|
||||
includeCheckbox
|
||||
/>
|
||||
))
|
||||
) : (
|
||||
<div className="text-sm px-2 py-2">No matching tags found</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="mx-2 mb-2 pt-2 border-t border-border">
|
||||
<input
|
||||
ref={inputRef}
|
||||
className="w-full border border-border py-0.5 px-2 rounded text-sm h-8 "
|
||||
placeholder="Find a tag"
|
||||
value={filterValue}
|
||||
onChange={(event) => setFilterValue(event.target.value)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function ChatFilters({
|
||||
timeRange,
|
||||
setTimeRange,
|
||||
selectedSources,
|
||||
setSelectedSources,
|
||||
selectedDocumentSets,
|
||||
setSelectedDocumentSets,
|
||||
selectedTags,
|
||||
setSelectedTags,
|
||||
availableDocumentSets,
|
||||
existingSources,
|
||||
availableTags,
|
||||
}: SourceSelectorProps) {
|
||||
const [filtersOpen, setFiltersOpen] = useState(false);
|
||||
const handleFiltersToggle = (value: boolean) => {
|
||||
setSelectedFilterType(null);
|
||||
setFiltersOpen(value);
|
||||
};
|
||||
const [selectedFilterType, setSelectedFilterType] =
|
||||
useState<FilterType | null>(null);
|
||||
|
||||
const handleSourceSelect = (source: SourceMetadata) => {
|
||||
setSelectedSources((prev: SourceMetadata[]) => {
|
||||
const prevSourceNames = prev.map((source) => source.internalName);
|
||||
if (prevSourceNames.includes(source.internalName)) {
|
||||
return prev.filter((s) => s.internalName !== source.internalName);
|
||||
} else {
|
||||
return [...prev, source];
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const handleDocumentSetSelect = (documentSetName: string) => {
|
||||
setSelectedDocumentSets((prev: string[]) => {
|
||||
if (prev.includes(documentSetName)) {
|
||||
return prev.filter((s) => s !== documentSetName);
|
||||
} else {
|
||||
return [...prev, documentSetName];
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const handleTagToggle = (tag: Tag) => {
|
||||
setSelectedTags((prev) => {
|
||||
if (containsObject(prev, tag)) {
|
||||
return prev.filter((t) => !objectsAreEquivalent(t, tag));
|
||||
} else {
|
||||
return [...prev, tag];
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const allSources = listSourceMetadata();
|
||||
const availableSources = allSources.filter((source) =>
|
||||
existingSources.includes(source.internalName)
|
||||
);
|
||||
|
||||
let popupDisplay = null;
|
||||
if (selectedFilterType === FilterType.Source) {
|
||||
popupDisplay = (
|
||||
<SourcesSection
|
||||
sources={availableSources}
|
||||
selectedSources={selectedSources.map((source) => source.internalName)}
|
||||
onSelect={handleSourceSelect}
|
||||
/>
|
||||
);
|
||||
} else if (selectedFilterType === FilterType.KnowledgeSet) {
|
||||
popupDisplay = (
|
||||
<KnowledgeSetsSection
|
||||
documentSets={availableDocumentSets}
|
||||
selectedDocumentSets={selectedDocumentSets}
|
||||
onSelect={handleDocumentSetSelect}
|
||||
/>
|
||||
);
|
||||
} else if (selectedFilterType === FilterType.TimeRange) {
|
||||
popupDisplay = (
|
||||
<TimeRangeSection
|
||||
selectedTimeRange={timeRange?.selectValue || null}
|
||||
onSelect={(timeRange) => {
|
||||
setTimeRange(timeRange);
|
||||
handleFiltersToggle(!filtersOpen);
|
||||
}}
|
||||
/>
|
||||
);
|
||||
} else if (selectedFilterType === FilterType.Tag) {
|
||||
popupDisplay = (
|
||||
<TagsSection
|
||||
availableTags={availableTags}
|
||||
selectedTags={selectedTags}
|
||||
onSelect={handleTagToggle}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
popupDisplay = (
|
||||
<SelectFilterType
|
||||
onSelect={(filterType) => setSelectedFilterType(filterType)}
|
||||
hasSources={availableSources.length > 0}
|
||||
hasKnowledgeSets={availableDocumentSets.length > 0}
|
||||
hasTags={availableTags.length > 0}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
<ControlledPopup
|
||||
isOpen={filtersOpen}
|
||||
setIsOpen={handleFiltersToggle}
|
||||
popupContent={popupDisplay}
|
||||
>
|
||||
<div className="flex">
|
||||
<BasicClickable onClick={() => handleFiltersToggle(!filtersOpen)}>
|
||||
<div className="flex text-xs">
|
||||
<FiFilter className="my-auto mr-1" /> Filter
|
||||
</div>
|
||||
</BasicClickable>
|
||||
</div>
|
||||
</ControlledPopup>
|
||||
|
||||
<div className="flex ml-4">
|
||||
{((timeRange && timeRange.selectValue !== undefined) ||
|
||||
selectedSources.length > 0 ||
|
||||
selectedDocumentSets.length > 0) && (
|
||||
<p className="text-xs my-auto mr-1">Currently applied:</p>
|
||||
)}
|
||||
<div className="flex flex-wrap gap-x-2">
|
||||
{timeRange && timeRange.selectValue && (
|
||||
<SelectedBubble onClick={() => setTimeRange(null)}>
|
||||
<div className="flex">{timeRange.selectValue}</div>
|
||||
</SelectedBubble>
|
||||
)}
|
||||
{existingSources.length > 0 &&
|
||||
selectedSources.map((source) => (
|
||||
<SelectedBubble
|
||||
key={source.internalName}
|
||||
onClick={() => handleSourceSelect(source)}
|
||||
>
|
||||
<>
|
||||
<SourceIcon sourceType={source.internalName} iconSize={16} />
|
||||
<span className="ml-2">{source.displayName}</span>
|
||||
</>
|
||||
</SelectedBubble>
|
||||
))}
|
||||
{selectedDocumentSets.length > 0 &&
|
||||
selectedDocumentSets.map((documentSetName) => (
|
||||
<SelectedBubble
|
||||
key={documentSetName}
|
||||
onClick={() => handleDocumentSetSelect(documentSetName)}
|
||||
>
|
||||
<>
|
||||
<div>
|
||||
<FiBookmark />
|
||||
</div>
|
||||
<span className="ml-2">{documentSetName}</span>
|
||||
</>
|
||||
</SelectedBubble>
|
||||
))}
|
||||
|
||||
{selectedTags.length > 0 &&
|
||||
selectedTags.map((tag) => (
|
||||
<SelectedBubble
|
||||
key={tag.tag_key + tag.tag_value}
|
||||
onClick={() => handleTagToggle(tag)}
|
||||
>
|
||||
<>
|
||||
<div>
|
||||
<FiTag />
|
||||
</div>
|
||||
<span className="ml-1 max-w-[100px] text-ellipsis line-clamp-1 break-all">
|
||||
{tag.tag_key}
|
||||
<b>=</b>
|
||||
{tag.tag_value}
|
||||
</span>
|
||||
</>
|
||||
</SelectedBubble>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -10,7 +10,7 @@ export const SEARCH_PARAM_NAMES = {
|
||||
MODEL_VERSION: "model-version",
|
||||
SYSTEM_PROMPT: "system-prompt",
|
||||
// user message
|
||||
USER_MESSAGE: "user-message",
|
||||
USER_PROMPT: "user-prompt",
|
||||
SUBMIT_ON_LOAD: "submit-on-load",
|
||||
// chat title
|
||||
TITLE: "title",
|
||||
|
||||
9
web/src/app/config/timeRange.tsx
Normal file
9
web/src/app/config/timeRange.tsx
Normal file
@@ -0,0 +1,9 @@
|
||||
import { getXDaysAgo, getXYearsAgo } from "@/lib/dateUtils";
|
||||
|
||||
export const timeRangeValues = [
|
||||
{ label: "Last 2 years", value: getXYearsAgo(2) },
|
||||
{ label: "Last year", value: getXYearsAgo(1) },
|
||||
{ label: "Last 30 days", value: getXDaysAgo(30) },
|
||||
{ label: "Last 7 days", value: getXDaysAgo(7) },
|
||||
{ label: "Today", value: getXDaysAgo(1) },
|
||||
];
|
||||
@@ -44,7 +44,7 @@ export function Modal({
|
||||
return (
|
||||
<div
|
||||
onMouseDown={handleMouseDown}
|
||||
className={`fixed inset-0 bg-black bg-opacity-25 backdrop-blur-sm
|
||||
className={`fixed inset-0 bg-black bg-opacity-25 backdrop-blur-sm h-full
|
||||
flex items-center justify-center z-50 transition-opacity duration-300 ease-in-out`}
|
||||
>
|
||||
<div
|
||||
@@ -72,7 +72,7 @@ export function Modal({
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex w-full flex-col justify-stretch">
|
||||
<div className="w-full flex flex-col h-full justify-stretch">
|
||||
{title && (
|
||||
<>
|
||||
<div className="flex mb-4">
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import { CodeBlock } from "@/app/chat/message/CodeBlock";
|
||||
import {
|
||||
MemoizedLink,
|
||||
MemoizedParagraph,
|
||||
} from "@/app/chat/message/MemoizedTextComponents";
|
||||
import React from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import remarkGfm from "remark-gfm";
|
||||
@@ -18,17 +22,8 @@ export const MinimalMarkdown: React.FC<MinimalMarkdownProps> = ({
|
||||
<ReactMarkdown
|
||||
className={`w-full text-wrap break-word ${className}`}
|
||||
components={{
|
||||
a: ({ node, ...props }) => (
|
||||
<a
|
||||
{...props}
|
||||
className="text-sm text-link hover:text-link-hover"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
/>
|
||||
),
|
||||
p: ({ node, ...props }) => (
|
||||
<p {...props} className="text-wrap break-word text-sm m-0 w-full" />
|
||||
),
|
||||
a: MemoizedLink,
|
||||
p: MemoizedParagraph,
|
||||
code: useCodeBlock
|
||||
? (props) => (
|
||||
<CodeBlock className="w-full" {...props} content={content} />
|
||||
|
||||
32
web/src/components/filters/TimeRangeSelector.tsx
Normal file
32
web/src/components/filters/TimeRangeSelector.tsx
Normal file
@@ -0,0 +1,32 @@
|
||||
import { DefaultDropdownElement } from "../Dropdown";
|
||||
export function TimeRangeSelector({
|
||||
value,
|
||||
onValueChange,
|
||||
className,
|
||||
timeRangeValues,
|
||||
}: {
|
||||
value: any;
|
||||
onValueChange: any;
|
||||
className: any;
|
||||
|
||||
timeRangeValues: { label: string; value: Date }[];
|
||||
}) {
|
||||
return (
|
||||
<div className={className}>
|
||||
{timeRangeValues.map((timeRangeValue) => (
|
||||
<DefaultDropdownElement
|
||||
key={timeRangeValue.label}
|
||||
name={timeRangeValue.label}
|
||||
onSelect={() =>
|
||||
onValueChange({
|
||||
to: new Date(),
|
||||
from: timeRangeValue.value,
|
||||
selectValue: timeRangeValue.label,
|
||||
})
|
||||
}
|
||||
isSelected={value?.selectValue === timeRangeValue.label}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import React from "react";
|
||||
import { getDisplayNameForModel } from "@/lib/hooks";
|
||||
import { structureValue } from "@/lib/llm/utils";
|
||||
import { checkLLMSupportsImageInput, structureValue } from "@/lib/llm/utils";
|
||||
import {
|
||||
getProviderIcon,
|
||||
LLMProviderDescriptor,
|
||||
@@ -13,6 +13,7 @@ interface LlmListProps {
|
||||
userDefault?: string | null;
|
||||
scrollable?: boolean;
|
||||
hideProviderIcon?: boolean;
|
||||
requiresImageGeneration?: boolean;
|
||||
}
|
||||
|
||||
export const LlmList: React.FC<LlmListProps> = ({
|
||||
@@ -21,6 +22,7 @@ export const LlmList: React.FC<LlmListProps> = ({
|
||||
onSelect,
|
||||
userDefault,
|
||||
scrollable,
|
||||
requiresImageGeneration,
|
||||
}) => {
|
||||
const llmOptionsByProvider: {
|
||||
[provider: string]: {
|
||||
@@ -76,21 +78,26 @@ export const LlmList: React.FC<LlmListProps> = ({
|
||||
User Default (currently {getDisplayNameForModel(userDefault)})
|
||||
</button>
|
||||
)}
|
||||
{llmOptions.map(({ name, icon, value }, index) => (
|
||||
<button
|
||||
type="button"
|
||||
key={index}
|
||||
className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${
|
||||
currentLlm == name
|
||||
? "bg-background-200"
|
||||
: "bg-background hover:bg-background-100"
|
||||
} text-left rounded`}
|
||||
onClick={() => onSelect(value)}
|
||||
>
|
||||
{icon({ size: 16 })}
|
||||
{getDisplayNameForModel(name)}
|
||||
</button>
|
||||
))}
|
||||
|
||||
{llmOptions.map(({ name, icon, value }, index) => {
|
||||
if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
key={index}
|
||||
className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${
|
||||
currentLlm == name
|
||||
? "bg-background-200"
|
||||
: "bg-background hover:bg-background-100"
|
||||
} text-left rounded`}
|
||||
onClick={() => onSelect(value)}
|
||||
>
|
||||
{icon({ size: 16 })}
|
||||
{getDisplayNameForModel(name)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -18,7 +18,7 @@ export default function ExceptionTraceModal({
|
||||
title="Full Exception Trace"
|
||||
onOutsideClick={onOutsideClick}
|
||||
>
|
||||
<div className="overflow-y-auto mb-6">
|
||||
<div className="overflow-y-auto include-scrollbar pr-3 h-full mb-6">
|
||||
<div className="mb-6">
|
||||
{!copyClicked ? (
|
||||
<div
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import { getXDaysAgo } from "@/lib/dateUtils";
|
||||
import { DateRangePickerValue } from "@tremor/react";
|
||||
import { FiCalendar, FiChevronDown, FiXCircle } from "react-icons/fi";
|
||||
import { CustomDropdown, DefaultDropdownElement } from "../Dropdown";
|
||||
|
||||
export const LAST_30_DAYS = "Last 30 days";
|
||||
export const LAST_7_DAYS = "Last 7 days";
|
||||
export const TODAY = "Today";
|
||||
import { CustomDropdown } from "../Dropdown";
|
||||
import { timeRangeValues } from "@/app/config/timeRange";
|
||||
import { TimeRangeSelector } from "@/components/filters/TimeRangeSelector";
|
||||
|
||||
export function DateRangeSelector({
|
||||
value,
|
||||
@@ -20,59 +17,23 @@ export function DateRangeSelector({
|
||||
<div>
|
||||
<CustomDropdown
|
||||
dropdown={
|
||||
<div
|
||||
<TimeRangeSelector
|
||||
value={value}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded-lg
|
||||
flex
|
||||
flex-col
|
||||
w-64
|
||||
max-h-96
|
||||
overflow-y-auto
|
||||
flex
|
||||
overscroll-contain`}
|
||||
>
|
||||
<DefaultDropdownElement
|
||||
key={LAST_30_DAYS}
|
||||
name={LAST_30_DAYS}
|
||||
onSelect={() =>
|
||||
onValueChange({
|
||||
to: new Date(),
|
||||
from: getXDaysAgo(30),
|
||||
selectValue: LAST_30_DAYS,
|
||||
})
|
||||
}
|
||||
isSelected={value?.selectValue === LAST_30_DAYS}
|
||||
/>
|
||||
|
||||
<DefaultDropdownElement
|
||||
key={LAST_7_DAYS}
|
||||
name={LAST_7_DAYS}
|
||||
onSelect={() =>
|
||||
onValueChange({
|
||||
to: new Date(),
|
||||
from: getXDaysAgo(7),
|
||||
selectValue: LAST_7_DAYS,
|
||||
})
|
||||
}
|
||||
isSelected={value?.selectValue === LAST_7_DAYS}
|
||||
/>
|
||||
|
||||
<DefaultDropdownElement
|
||||
key={TODAY}
|
||||
name={TODAY}
|
||||
onSelect={() =>
|
||||
onValueChange({
|
||||
to: new Date(),
|
||||
from: getXDaysAgo(1),
|
||||
selectValue: TODAY,
|
||||
})
|
||||
}
|
||||
isSelected={value?.selectValue === TODAY}
|
||||
/>
|
||||
</div>
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded-lg
|
||||
flex
|
||||
flex-col
|
||||
w-64
|
||||
max-h-96
|
||||
overflow-y-auto
|
||||
flex
|
||||
overscroll-contain`}
|
||||
timeRangeValues={timeRangeValues}
|
||||
onValueChange={onValueChange}
|
||||
/>
|
||||
}
|
||||
>
|
||||
<div
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import {
|
||||
DanswerDocument,
|
||||
DocumentRelevance,
|
||||
Relevance,
|
||||
SearchDanswerDocument,
|
||||
} from "@/lib/search/interfaces";
|
||||
import { DocumentFeedbackBlock } from "./DocumentFeedbackBlock";
|
||||
@@ -12,11 +11,10 @@ import { PopupSpec } from "../admin/connectors/Popup";
|
||||
import { DocumentUpdatedAtBadge } from "./DocumentUpdatedAtBadge";
|
||||
import { SourceIcon } from "../SourceIcon";
|
||||
import { MetadataBadge } from "../MetadataBadge";
|
||||
import { BookIcon, CheckmarkIcon, LightBulbIcon, XIcon } from "../icons/icons";
|
||||
import { BookIcon, LightBulbIcon } from "../icons/icons";
|
||||
|
||||
import { FaStar } from "react-icons/fa";
|
||||
import { FiTag } from "react-icons/fi";
|
||||
import { DISABLE_LLM_DOC_RELEVANCE } from "@/lib/constants";
|
||||
import { SettingsContext } from "../settings/SettingsProvider";
|
||||
import { CustomTooltip, TooltipGroup } from "../tooltip/CustomTooltip";
|
||||
import { WarningCircle } from "@phosphor-icons/react";
|
||||
|
||||
@@ -42,6 +42,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
|
||||
if (!results[0].ok) {
|
||||
if (results[0].status === 403) {
|
||||
settings = {
|
||||
gpu_enabled: false,
|
||||
chat_page_enabled: true,
|
||||
search_page_enabled: true,
|
||||
default_page: "search",
|
||||
|
||||
@@ -14,15 +14,20 @@ export function orderAssistantsForUser(
|
||||
])
|
||||
);
|
||||
|
||||
assistants = assistants.filter((assistant) =>
|
||||
let filteredAssistants = assistants.filter((assistant) =>
|
||||
chosenAssistantsSet.has(assistant.id)
|
||||
);
|
||||
|
||||
assistants.sort((a, b) => {
|
||||
if (filteredAssistants.length == 0) {
|
||||
return assistants;
|
||||
}
|
||||
|
||||
filteredAssistants.sort((a, b) => {
|
||||
const orderA = assistantOrderMap.get(a.id) ?? Number.MAX_SAFE_INTEGER;
|
||||
const orderB = assistantOrderMap.get(b.id) ?? Number.MAX_SAFE_INTEGER;
|
||||
return orderA - orderB;
|
||||
});
|
||||
return filteredAssistants;
|
||||
}
|
||||
|
||||
return assistants;
|
||||
|
||||
@@ -873,6 +873,9 @@ export function createConnectorValidationSchema(
|
||||
});
|
||||
}
|
||||
|
||||
export const defaultPruneFreqDays = 30; // 30 days
|
||||
export const defaultRefreshFreqMinutes = 30; // 30 minutes
|
||||
|
||||
// CONNECTORS
|
||||
export interface ConnectorBase<T> {
|
||||
name: string;
|
||||
|
||||
@@ -5,6 +5,13 @@ export function getXDaysAgo(daysAgo: number) {
|
||||
return daysAgoDate;
|
||||
}
|
||||
|
||||
export function getXYearsAgo(yearsAgo: number) {
|
||||
const today = new Date();
|
||||
const yearsAgoDate = new Date(today);
|
||||
yearsAgoDate.setFullYear(yearsAgoDate.getFullYear() - yearsAgo);
|
||||
return yearsAgoDate;
|
||||
}
|
||||
|
||||
export const timestampToDateString = (timestamp: string) => {
|
||||
const date = new Date(timestamp);
|
||||
const year = date.getFullYear();
|
||||
|
||||
@@ -62,7 +62,7 @@ export function getLLMProviderOverrideForPersona(
|
||||
return null;
|
||||
}
|
||||
|
||||
const MODEL_NAMES_SUPPORTING_IMAGES = [
|
||||
const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4-vision-preview",
|
||||
@@ -84,8 +84,31 @@ const MODEL_NAMES_SUPPORTING_IMAGES = [
|
||||
];
|
||||
|
||||
export function checkLLMSupportsImageInput(model: string) {
|
||||
return MODEL_NAMES_SUPPORTING_IMAGES.some((modelName) => modelName === model);
|
||||
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
|
||||
(modelName) => modelName === model
|
||||
);
|
||||
}
|
||||
|
||||
const MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT = [
|
||||
["openai", "gpt-4o"],
|
||||
["openai", "gpt-4o-mini"],
|
||||
["openai", "gpt-4-vision-preview"],
|
||||
["openai", "gpt-4-turbo"],
|
||||
["openai", "gpt-4-1106-vision-preview"],
|
||||
["azure", "gpt-4o"],
|
||||
["azure", "gpt-4o-mini"],
|
||||
["azure", "gpt-4-vision-preview"],
|
||||
["azure", "gpt-4-turbo"],
|
||||
["azure", "gpt-4-1106-vision-preview"],
|
||||
];
|
||||
|
||||
export function checkLLMSupportsImageOutput(provider: string, model: string) {
|
||||
return MODEL_PROVIDER_PAIRS_SUPPORTING_IMAGE_OUTPUT.some(
|
||||
(modelProvider) =>
|
||||
modelProvider[0] === provider && modelProvider[1] === model
|
||||
);
|
||||
}
|
||||
|
||||
export const structureValue = (
|
||||
name: string,
|
||||
provider: string,
|
||||
|
||||
Reference in New Issue
Block a user