Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits: eval/split ... eval/nomic (44 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 9d48c79de9 | |
| | a4d71e08aa | |
| | 546bfbd24b | |
| | 27824d6cc6 | |
| | 9d5c4ad634 | |
| | 9b32003816 | |
| | 8bc4123ed7 | |
| | d58aaf7a59 | |
| | a0056a1b3c | |
| | d2584c773a | |
| | 807bef8ada | |
| | 5afddacbb2 | |
| | 4fb6a88f1e | |
| | 7057be6a88 | |
| | 91be8e7bfb | |
| | 9651ea828b | |
| | 6ee74bd0d1 | |
| | 48a0d29a5c | |
| | 6ff8e6c0ea | |
| | 2470c68506 | |
| | 866bc803b1 | |
| | 9c6084bd0d | |
| | a0b46c60c6 | |
| | 4029233df0 | |
| | 6c88c0156c | |
| | 33332d08f2 | |
| | 17005fb705 | |
| | 48a7fe80b1 | |
| | 1276732409 | |
| | f91b92a898 | |
| | 6222f533be | |
| | 1b49d17239 | |
| | 2f5f19642e | |
| | 6db4634871 | |
| | 5cfed45cef | |
| | 581ffde35a | |
| | 6313e6d91d | |
| | c09c94bf32 | |
| | 0e8ba111c8 | |
| | 2ba24b1734 | |
| | 44820b4909 | |
| | eb3e7610fc | |
| | 7fbbb174bb | |
| | 3854ca11af | |
@@ -0,0 +1,41 @@
"""add_llm_group_permissions_control

Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558

"""
from alembic import op
import sqlalchemy as sa


revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "llm_provider__user_group",
        sa.Column("llm_provider_id", sa.Integer(), nullable=False),
        sa.Column("user_group_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["llm_provider_id"],
            ["llm_provider.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_group_id"],
            ["user_group.id"],
        ),
        sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
    )
    op.add_column(
        "llm_provider",
        sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
    )


def downgrade() -> None:
    op.drop_table("llm_provider__user_group")
    op.drop_column("llm_provider", "is_public")
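For illustration, the revision above can be applied like any other Alembic migration. A minimal sketch using Alembic's Python API (the config path is an assumption; the usual CLI equivalent is `alembic upgrade 795b20b85b4b`):

```python
# Illustrative only: apply the revision programmatically.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")  # assumed location of the project's Alembic config
command.upgrade(alembic_cfg, "795b20b85b4b")  # or "head" to run all pending revisions
```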
backend/alembic/versions/91ffac7e65b3_add_expiry_time.py (Normal file, 27 lines)
@@ -0,0 +1,27 @@
"""add expiry time
Revision ID: 91ffac7e65b3
Revises: bc9771dccadf
Create Date: 2024-06-24 09:39:56.462242
"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "91ffac7e65b3"
down_revision = "795b20b85b4b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
    )


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("user", "oidc_expiry")
    # ### end Alembic commands ###
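The new `oidc_expiry` column stores a timezone-aware timestamp. A small sketch of the conversion the auth change below performs before writing it (the epoch value here is made up for illustration):

```python
from datetime import datetime, timezone

expires_at = 1721390075  # example OAuth expires_at, seconds since the epoch
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
# this timezone-aware datetime is what ends up in user.oidc_expiry
```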
@@ -1,6 +1,8 @@
 import smtplib
 import uuid
 from collections.abc import AsyncGenerator
+from datetime import datetime
+from datetime import timezone
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from typing import Optional
@@ -52,6 +54,7 @@ from danswer.db.auth import get_user_db
 from danswer.db.engine import get_session
 from danswer.db.models import AccessToken
 from danswer.db.models import User
+from danswer.db.users import get_user_by_email
 from danswer.utils.logger import setup_logger
 from danswer.utils.telemetry import optional_telemetry
 from danswer.utils.telemetry import RecordType
@@ -92,12 +95,20 @@ def user_needs_to_be_verified() -> bool:
     return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION


-def verify_email_in_whitelist(email: str) -> None:
+def verify_email_is_invited(email: str) -> None:
     whitelist = get_invited_users()
     if (whitelist and email not in whitelist) or not email:
         raise PermissionError("User not on allowed user whitelist")


+def verify_email_in_whitelist(
+    email: str,
+    db_session: Session = Depends(get_session),
+) -> None:
+    if not get_user_by_email(email, db_session):
+        verify_email_is_invited(email)
+
+
 def verify_email_domain(email: str) -> None:
     if VALID_EMAIL_DOMAINS:
         if email.count("@") != 1:
@@ -147,7 +158,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
         safe: bool = False,
         request: Optional[Request] = None,
     ) -> models.UP:
-        verify_email_in_whitelist(user_create.email)
+        verify_email_is_invited(user_create.email)
         verify_email_domain(user_create.email)
         if hasattr(user_create, "role"):
             user_count = await get_user_count()
@@ -173,7 +184,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
         verify_email_in_whitelist(account_email)
         verify_email_domain(account_email)

-        return await super().oauth_callback(  # type: ignore
+        user = await super().oauth_callback(  # type: ignore
             oauth_name=oauth_name,
             access_token=access_token,
             account_id=account_id,
@@ -185,6 +196,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
             is_verified_by_default=is_verified_by_default,
         )

+        # NOTE: google oauth expires after 1hr. We don't want to force the user to
+        # re-authenticate that frequently, so for now we'll just ignore this for
+        # google oauth users
+        if expires_at and AUTH_TYPE != AuthType.GOOGLE_OAUTH:
+            oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
+            await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
+        return user
+
     async def on_after_register(
         self, user: User, request: Optional[Request] = None
     ) -> None:
@@ -227,10 +246,12 @@ cookie_transport = CookieTransport(
 def get_database_strategy(
     access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
 ) -> DatabaseStrategy:
-    return DatabaseStrategy(
+    strategy = DatabaseStrategy(
         access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS  # type: ignore
     )
+
+    return strategy


 auth_backend = AuthenticationBackend(
     name="database",
@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
 from danswer.background.task_utils import name_cc_cleanup_task
 from danswer.background.task_utils import name_cc_prune_task
 from danswer.background.task_utils import name_document_set_sync_task
+from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
 from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
-from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
 from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
     rate_limit_builder,
 )
@@ -80,7 +80,7 @@ def should_prune_cc_pair(
             return True
         return False

-    if PREVENT_SIMULTANEOUS_PRUNING:
+    if not ALLOW_SIMULTANEOUS_PRUNING:
         pruning_type_task_name = name_cc_prune_task()
         last_pruning_type_task = get_latest_task_by_type(
             pruning_type_task_name, db_session
@@ -89,11 +89,9 @@ def should_prune_cc_pair(
         if last_pruning_type_task and check_task_is_live_and_not_timed_out(
             last_pruning_type_task, db_session
         ):
             logger.info("Another Connector is already pruning. Skipping.")
             return False

     if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
         logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
         return False

     if not last_pruning_task.start_time:
@@ -20,7 +20,7 @@ from danswer.db.connector_credential_pair import update_connector_credential_pai
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import mark_attempt_failed
-from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
+from danswer.db.index_attempt import mark_attempt_in_progress
 from danswer.db.index_attempt import mark_attempt_succeeded
 from danswer.db.index_attempt import update_docs_indexed
 from danswer.db.models import IndexAttempt
@@ -299,9 +299,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
     )

-    # only commit once, to make sure this all happens in a single transaction
-    mark_attempt_in_progress__no_commit(attempt)
-    if attempt.embedding_model.status != IndexModelStatus.PRESENT:
-        db_session.commit()
+    mark_attempt_in_progress(attempt, db_session)

     return attempt
@@ -33,7 +33,7 @@ from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
 from danswer.db.models import IndexModelStatus
 from danswer.db.swap_index import check_index_swap
-from danswer.search.search_nlp_models import warm_up_encoders
+from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import global_version
 from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -271,6 +271,8 @@ def kickoff_indexing_jobs(
     # Don't include jobs waiting in the Dask queue that just haven't started running
     # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
     with Session(engine) as db_session:
+        # get_not_started_index_attempts orders its returned results from oldest to newest
+        # we must process attempts in a FIFO manner to prevent connector starvation
         new_indexing_attempts = [
             (attempt, attempt.embedding_model)
             for attempt in get_not_started_index_attempts(db_session)
@@ -51,7 +51,7 @@ from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_llms_for_persona
 from danswer.llm.factory import get_main_llm_from_tuple
 from danswer.llm.interfaces import LLMConfig
-from danswer.llm.utils import get_default_llm_tokenizer
+from danswer.natural_language_processing.utils import get_default_llm_tokenizer
 from danswer.search.enums import OptionalSearchSetting
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import SearchType
@@ -187,37 +187,46 @@ def _handle_internet_search_tool_response_summary(
     )


-def _check_should_force_search(
-    new_msg_req: CreateChatMessageRequest,
-) -> ForceUseTool | None:
-    # If files are already provided, don't run the search tool
+def _get_force_search_settings(
+    new_msg_req: CreateChatMessageRequest, tools: list[Tool]
+) -> ForceUseTool:
+    internet_search_available = any(
+        isinstance(tool, InternetSearchTool) for tool in tools
+    )
+    search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
+
+    if not internet_search_available and not search_tool_available:
+        # Does not matter much which tool is set here as force is false and neither tool is available
+        return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
+
+    tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
+    # Currently, the internet search tool does not support query override
+    args = (
+        {"query": new_msg_req.query_override}
+        if new_msg_req.query_override and tool_name == SearchTool._NAME
+        else None
+    )
+
     if new_msg_req.file_descriptors:
-        return None
+        # If user has uploaded files they're using, don't run any of the search tools
+        return ForceUseTool(force_use=False, tool_name=tool_name)

-    if (
-        new_msg_req.query_override
-        or (
+    should_force_search = any(
+        [
             new_msg_req.retrieval_options
-            and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
-        )
-        or new_msg_req.search_doc_ids
-        or DISABLE_LLM_CHOOSE_SEARCH
-    ):
-        args = (
-            {"query": new_msg_req.query_override}
-            if new_msg_req.query_override
-            else None
-        )
-        # if we are using selected docs, just put something here so the Tool doesn't need
-        # to build its own args via an LLM call
-        if new_msg_req.search_doc_ids:
-            args = {"query": new_msg_req.message}
+            and new_msg_req.retrieval_options.run_search
+            == OptionalSearchSetting.ALWAYS,
+            new_msg_req.search_doc_ids,
+            DISABLE_LLM_CHOOSE_SEARCH,
+        ]
+    )

-        return ForceUseTool(
-            tool_name=SearchTool._NAME,
-            args=args,
-        )
-    return None
+    if should_force_search:
+        # If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
+        args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
+        return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
+
+    return ForceUseTool(force_use=False, tool_name=tool_name, args=args)


 ChatPacket = (
@@ -253,7 +262,6 @@ def stream_chat_message_objects(
     2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
     3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
     4. [always] Details on the final AI response message that is created
-
     """
     try:
         user_id = user.id if user is not None else None
@@ -361,6 +369,14 @@ def stream_chat_message_objects(
                 "when the last message is not a user message."
             )

+        # Disable Query Rephrasing for the first message
+        # This leads to a better first response since the LLM rephrasing the question
+        # leads to worst search quality
+        if not history_msgs:
+            new_msg_req.query_override = (
+                new_msg_req.query_override or new_msg_req.message
+            )
+
         # load all files needed for this chat chain in memory
         files = load_all_chat_files(
             history_msgs, new_msg_req.file_descriptors, db_session
@@ -576,11 +592,7 @@ def stream_chat_message_objects(
                 PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
             ],
             tools=tools,
-            force_use_tool=(
-                _check_should_force_search(new_msg_req)
-                if search_tool and len(tools) == 1
-                else None
-            ),
+            force_use_tool=_get_force_search_settings(new_msg_req, tools),
         )

         reference_db_search_docs = None
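As a rough guide to the new behaviour, here is a self-contained sketch that mirrors the decision logic of `_get_force_search_settings` for the common cases. The types are simplified stand-ins (not the actual danswer classes), and it omits the `DISABLE_LLM_CHOOSE_SEARCH` and `run_search == ALWAYS` branches:

```python
from dataclasses import dataclass, field


@dataclass
class FakeRequest:  # stand-in for CreateChatMessageRequest
    message: str
    query_override: str | None = None
    search_doc_ids: list[int] = field(default_factory=list)
    file_descriptors: list[str] = field(default_factory=list)


def decide(req: FakeRequest) -> tuple[bool, dict | None]:
    """Returns (force_search, tool_args), mirroring the helper above."""
    args = {"query": req.query_override} if req.query_override else None
    if req.file_descriptors:   # uploaded files: never force a search tool
        return False, args
    if req.search_doc_ids:     # pre-selected docs: force, query is the raw message
        return True, {"query": req.message}
    if req.query_override:     # explicit override: force with that query
        return True, args
    return False, args         # otherwise let the LLM choose


print(decide(FakeRequest(message="status of project X", search_doc_ids=[42])))
# (True, {'query': 'status of project X'})
```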
@@ -214,8 +214,8 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (

 DEFAULT_PRUNING_FREQ = 60 * 60 * 24  # Once a day

-PREVENT_SIMULTANEOUS_PRUNING = (
-    os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
+ALLOW_SIMULTANEOUS_PRUNING = (
+    os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
 )

 # This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
@@ -44,7 +44,6 @@ QUERY_EVENT_ID = "query_event_id"
 LLM_CHUNKS = "llm_chunks"

 # For chunking/processing chunks
-MAX_CHUNK_TITLE_LEN = 1000
 RETURN_SEPARATOR = "\n\r\n"
 SECTION_SEPARATOR = "\n\n"
 # For combining attributes, doesn't have to be unique/perfect to work
@@ -12,17 +12,13 @@ import os
 # The useable models configured as below must be SentenceTransformer compatible
 # NOTE: DO NOT CHANGE SET THESE UNLESS YOU KNOW WHAT YOU ARE DOING
 # IDEALLY, YOU SHOULD CHANGE EMBEDDING MODELS VIA THE UI
-DEFAULT_DOCUMENT_ENCODER_MODEL = "intfloat/e5-base-v2"
-DOCUMENT_ENCODER_MODEL = (
-    os.environ.get("DOCUMENT_ENCODER_MODEL") or DEFAULT_DOCUMENT_ENCODER_MODEL
-)
+DEFAULT_DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1"
+DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1"
 # If the below is changed, Vespa deployment must also be changed
 DOC_EMBEDDING_DIM = int(os.environ.get("DOC_EMBEDDING_DIM") or 768)
 # Model should be chosen with 512 context size, ideally don't change this
 DOC_EMBEDDING_CONTEXT_SIZE = 512
-NORMALIZE_EMBEDDINGS = (
-    os.environ.get("NORMALIZE_EMBEDDINGS") or "true"
-).lower() == "true"
+NORMALIZE_EMBEDDINGS = False

 # Old default model settings, which are needed for an automatic easy upgrade
 OLD_DEFAULT_DOCUMENT_ENCODER_MODEL = "thenlper/gte-small"
@@ -34,8 +30,8 @@ OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS = False
 SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
 SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
 # Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
-ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
-ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
+ASYM_QUERY_PREFIX = "search_query: "
+ASYM_PASSAGE_PREFIX = "search_document: "
 # Purely an optimization, memory limitation consideration
 BATCH_SIZE_ENCODE_CHUNKS = 8
 # For score display purposes, only way is to know the expected ranges
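The new hardcoded prefixes are the asymmetric task prefixes that nomic-embed-text-v1 expects. A small sketch of how they would be applied to texts before encoding (the encode call itself lives in the model server code, which is not part of this diff; the example strings are invented):

```python
ASYM_QUERY_PREFIX = "search_query: "
ASYM_PASSAGE_PREFIX = "search_document: "

query = "how do I rotate my API key?"
passage = "API keys can be rotated from the admin panel."

texts_to_embed = [
    ASYM_QUERY_PREFIX + query,      # applied at query time
    ASYM_PASSAGE_PREFIX + passage,  # applied at indexing time
]
print(texts_to_embed)
```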
@@ -56,6 +56,16 @@ def extract_text_from_content(content: dict) -> str:
     return " ".join(texts)


+def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
+    if hasattr(jira_issue.fields, field):
+        return getattr(jira_issue.fields, field)
+
+    try:
+        return jira_issue.raw["fields"][field]
+    except Exception:
+        return None
+
+
 def _get_comment_strs(
     jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
 ) -> list[str]:
@@ -117,8 +127,10 @@ def fetch_jira_issues_batch(
             continue

         comments = _get_comment_strs(jira, comment_email_blacklist)
-        semantic_rep = f"{jira.fields.description}\n" + "\n".join(
-            [f"Comment: {comment}" for comment in comments]
+        semantic_rep = (
+            f"{jira.fields.description}\n"
+            if jira.fields.description
+            else "" + "\n".join([f"Comment: {comment}" for comment in comments])
         )

         page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -147,14 +159,18 @@ def fetch_jira_issues_batch(
                 pass

         metadata_dict = {}
-        if jira.fields.priority:
-            metadata_dict["priority"] = jira.fields.priority.name
-        if jira.fields.status:
-            metadata_dict["status"] = jira.fields.status.name
-        if jira.fields.resolution:
-            metadata_dict["resolution"] = jira.fields.resolution.name
-        if jira.fields.labels:
-            metadata_dict["label"] = jira.fields.labels
+        priority = best_effort_get_field_from_issue(jira, "priority")
+        if priority:
+            metadata_dict["priority"] = priority.name
+        status = best_effort_get_field_from_issue(jira, "status")
+        if status:
+            metadata_dict["status"] = status.name
+        resolution = best_effort_get_field_from_issue(jira, "resolution")
+        if resolution:
+            metadata_dict["resolution"] = resolution.name
+        labels = best_effort_get_field_from_issue(jira, "labels")
+        if labels:
+            metadata_dict["label"] = labels

         doc_batch.append(
             Document(
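A minimal sketch of the fallback behaviour `best_effort_get_field_from_issue` adds, using a stand-in object instead of the jira library (the custom field name is invented for illustration):

```python
class FakeIssue:
    class fields:  # attribute-style access, like jira_issue.fields.priority
        priority = "High"

    raw = {"fields": {"priority": "High", "customfield_10042": "Q3-roadmap"}}


def best_effort_get_field(issue, field):
    if hasattr(issue.fields, field):
        return getattr(issue.fields, field)
    try:
        return issue.raw["fields"][field]  # fall back to the raw API payload
    except Exception:
        return None


print(best_effort_get_field(FakeIssue(), "priority"))           # High
print(best_effort_get_field(FakeIssue(), "customfield_10042"))  # Q3-roadmap
print(best_effort_get_field(FakeIssue(), "nonexistent"))        # None
```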
@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
         self.permissions: DiscoursePerms | None = None
         self.active_categories: set | None = None

-    @rate_limit_builder(max_calls=100, period=60)
+    @rate_limit_builder(max_calls=50, period=60)
     def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
         if not self.permissions:
             raise ConnectorMissingCredentialError("Discourse")
@@ -114,7 +114,9 @@ class DocumentBase(BaseModel):
     title: str | None = None
     from_ingestion_api: bool = False

-    def get_title_for_document_index(self) -> str | None:
+    def get_title_for_document_index(
+        self,
+    ) -> str | None:
         # If title is explicitly empty, return a None here for embedding purposes
         if self.title == "":
             return None
@@ -15,6 +15,7 @@ from playwright.sync_api import BrowserContext
 from playwright.sync_api import Playwright
 from playwright.sync_api import sync_playwright
 from requests_oauthlib import OAuth2Session  # type:ignore
+from urllib3.exceptions import MaxRetryError

 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_ID
@@ -83,6 +84,13 @@ def check_internet_connection(url: str) -> None:
     try:
         response = requests.get(url, timeout=3)
         response.raise_for_status()
+    except requests.exceptions.SSLError as e:
+        cause = (
+            e.args[0].reason
+            if isinstance(e.args, tuple) and isinstance(e.args[0], MaxRetryError)
+            else e.args
+        )
+        raise Exception(f"SSL error {str(cause)}")
     except (requests.RequestException, ValueError):
         raise Exception(f"Unable to reach {url} - check your internet connection")
@@ -50,9 +50,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
 from danswer.one_shot_answer.models import ThreadMessage
 from danswer.search.retrieval.search_runner import download_nltk_data
-from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.server.manage.models import SlackBotTokens
 from danswer.utils.logger import setup_logger
 from shared_configs.configs import MODEL_SERVER_HOST
@@ -15,7 +15,7 @@ from danswer.db.models import CloudEmbeddingProvider
 from danswer.db.models import EmbeddingModel
 from danswer.db.models import IndexModelStatus
 from danswer.indexing.models import EmbeddingModelDetail
-from danswer.search.search_nlp_models import clean_model_name
+from danswer.natural_language_processing.search_nlp_models import clean_model_name
 from danswer.server.manage.embedding.models import (
     CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
 )
@@ -65,9 +65,12 @@ def get_inprogress_index_attempts(

 def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
     """This eagerly loads the connector and credential so that the db_session can be expired
-    before running long-living indexing jobs, which causes increasing memory usage"""
+    before running long-living indexing jobs, which causes increasing memory usage.
+
+    Results are ordered by time_created (oldest to newest)."""
     stmt = select(IndexAttempt)
     stmt = stmt.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
+    stmt = stmt.order_by(IndexAttempt.time_created)
     stmt = stmt.options(
         joinedload(IndexAttempt.connector), joinedload(IndexAttempt.credential)
     )
@@ -75,11 +78,13 @@ def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
     return list(new_attempts.all())


-def mark_attempt_in_progress__no_commit(
+def mark_attempt_in_progress(
     index_attempt: IndexAttempt,
+    db_session: Session,
 ) -> None:
     index_attempt.status = IndexingStatus.IN_PROGRESS
     index_attempt.time_started = index_attempt.time_started or func.now()  # type: ignore
+    db_session.commit()


 def mark_attempt_succeeded(
@@ -1,15 +1,41 @@
 from sqlalchemy import delete
+from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy.orm import Session

 from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
 from danswer.db.models import LLMProvider as LLMProviderModel
+from danswer.db.models import LLMProvider__UserGroup
+from danswer.db.models import User
+from danswer.db.models import User__UserGroup
 from danswer.server.manage.embedding.models import CloudEmbeddingProvider
 from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
 from danswer.server.manage.llm.models import FullLLMProvider
 from danswer.server.manage.llm.models import LLMProviderUpsertRequest


+def update_group_llm_provider_relationships__no_commit(
+    llm_provider_id: int,
+    group_ids: list[int] | None,
+    db_session: Session,
+) -> None:
+    # Delete existing relationships
+    db_session.query(LLMProvider__UserGroup).filter(
+        LLMProvider__UserGroup.llm_provider_id == llm_provider_id
+    ).delete(synchronize_session="fetch")
+
+    # Add new relationships from given group_ids
+    if group_ids:
+        new_relationships = [
+            LLMProvider__UserGroup(
+                llm_provider_id=llm_provider_id,
+                user_group_id=group_id,
+            )
+            for group_id in group_ids
+        ]
+        db_session.add_all(new_relationships)
+
+
 def upsert_cloud_embedding_provider(
     db_session: Session, provider: CloudEmbeddingProviderCreationRequest
 ) -> CloudEmbeddingProvider:
@@ -36,36 +62,35 @@ def upsert_llm_provider(
     existing_llm_provider = db_session.scalar(
         select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
     )
-    if existing_llm_provider:
-        existing_llm_provider.provider = llm_provider.provider
-        existing_llm_provider.api_key = llm_provider.api_key
-        existing_llm_provider.api_base = llm_provider.api_base
-        existing_llm_provider.api_version = llm_provider.api_version
-        existing_llm_provider.custom_config = llm_provider.custom_config
-        existing_llm_provider.default_model_name = llm_provider.default_model_name
-        existing_llm_provider.fast_default_model_name = (
-            llm_provider.fast_default_model_name
-        )
-        existing_llm_provider.model_names = llm_provider.model_names
-        db_session.commit()
-        return FullLLMProvider.from_model(existing_llm_provider)
-    # if it does not exist, create a new entry
-    llm_provider_model = LLMProviderModel(
-        name=llm_provider.name,
-        provider=llm_provider.provider,
-        api_key=llm_provider.api_key,
-        api_base=llm_provider.api_base,
-        api_version=llm_provider.api_version,
-        custom_config=llm_provider.custom_config,
-        default_model_name=llm_provider.default_model_name,
-        fast_default_model_name=llm_provider.fast_default_model_name,
-        model_names=llm_provider.model_names,
-        is_default_provider=None,

+    if not existing_llm_provider:
+        existing_llm_provider = LLMProviderModel(name=llm_provider.name)
+        db_session.add(existing_llm_provider)
+
+    existing_llm_provider.provider = llm_provider.provider
+    existing_llm_provider.api_key = llm_provider.api_key
+    existing_llm_provider.api_base = llm_provider.api_base
+    existing_llm_provider.api_version = llm_provider.api_version
+    existing_llm_provider.custom_config = llm_provider.custom_config
+    existing_llm_provider.default_model_name = llm_provider.default_model_name
+    existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
+    existing_llm_provider.model_names = llm_provider.model_names
+    existing_llm_provider.is_public = llm_provider.is_public
+
+    if not existing_llm_provider.id:
+        # If its not already in the db, we need to generate an ID by flushing
+        db_session.flush()
+
+    # Make sure the relationship table stays up to date
+    update_group_llm_provider_relationships__no_commit(
+        llm_provider_id=existing_llm_provider.id,
+        group_ids=llm_provider.groups,
+        db_session=db_session,
     )
-    db_session.add(llm_provider_model)

     db_session.commit()

-    return FullLLMProvider.from_model(llm_provider_model)
+    return FullLLMProvider.from_model(existing_llm_provider)


 def fetch_existing_embedding_providers(
@@ -74,8 +99,29 @@ def fetch_existing_embedding_providers(
     return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())


-def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
-    return list(db_session.scalars(select(LLMProviderModel)).all())
+def fetch_existing_llm_providers(
+    db_session: Session,
+    user: User | None = None,
+) -> list[LLMProviderModel]:
+    if not user:
+        return list(db_session.scalars(select(LLMProviderModel)).all())
+    stmt = select(LLMProviderModel).distinct()
+    user_groups_subquery = (
+        select(User__UserGroup.user_group_id)
+        .where(User__UserGroup.user_id == user.id)
+        .subquery()
+    )
+    access_conditions = or_(
+        LLMProviderModel.is_public,
+        LLMProviderModel.id.in_(  # User is part of a group that has access
+            select(LLMProvider__UserGroup.llm_provider_id).where(
+                LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery)  # type: ignore
+            )
+        ),
+    )
+    stmt = stmt.where(access_conditions)
+
+    return list(db_session.scalars(stmt).all())


 def fetch_embedding_provider(
@@ -119,6 +165,13 @@ def remove_embedding_provider(


 def remove_llm_provider(db_session: Session, provider_id: int) -> None:
+    # Remove LLMProvider's dependent relationships
+    db_session.execute(
+        delete(LLMProvider__UserGroup).where(
+            LLMProvider__UserGroup.llm_provider_id == provider_id
+        )
+    )
+    # Remove LLMProvider
     db_session.execute(
         delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
     )
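The access rule the new query encodes reduces to: a provider is visible when it is public, or when it shares at least one user group with the requesting user. A plain-Python sketch of that rule (not the SQLAlchemy code above; the group IDs are invented):

```python
def provider_visible(
    is_public: bool, provider_group_ids: set[int], user_group_ids: set[int]
) -> bool:
    return is_public or bool(provider_group_ids & user_group_ids)


print(provider_visible(True, set(), {1}))       # True:  public providers are visible to everyone
print(provider_visible(False, {3, 7}, {7, 9}))  # True:  a shared group grants access
print(provider_visible(False, {3}, {9}))        # False: private and no common group
```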
@@ -11,6 +11,7 @@ from uuid import UUID
 from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
 from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
 from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
+from fastapi_users_db_sqlalchemy.generics import TIMESTAMPAware
 from sqlalchemy import Boolean
 from sqlalchemy import DateTime
 from sqlalchemy import Enum
@@ -120,6 +121,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
         postgresql.ARRAY(Integer), nullable=True
     )

+    oidc_expiry: Mapped[datetime.datetime] = mapped_column(
+        TIMESTAMPAware(timezone=True), nullable=True
+    )
+
     # relationships
     credentials: Mapped[list["Credential"]] = relationship(
         "Credential", back_populates="user", lazy="joined"
@@ -932,6 +937,13 @@ class LLMProvider(Base):

     # should only be set for a single provider
     is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
+    # EE only
+    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
+    groups: Mapped[list["UserGroup"]] = relationship(
+        "UserGroup",
+        secondary="llm_provider__user_group",
+        viewonly=True,
+    )


 class CloudEmbeddingProvider(Base):
@@ -1109,7 +1121,6 @@ class Persona(Base):
     # where lower value IDs (e.g. created earlier) are displayed first
     display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
     deleted: Mapped[bool] = mapped_column(Boolean, default=False)
-    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

     # These are only defaults, users can select from all if desired
     prompts: Mapped[list[Prompt]] = relationship(
@@ -1137,6 +1148,7 @@ class Persona(Base):
         viewonly=True,
     )
     # EE only
+    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
     groups: Mapped[list["UserGroup"]] = relationship(
         "UserGroup",
         secondary="persona__user_group",
@@ -1360,6 +1372,17 @@ class Persona__UserGroup(Base):
     )


+class LLMProvider__UserGroup(Base):
+    __tablename__ = "llm_provider__user_group"
+
+    llm_provider_id: Mapped[int] = mapped_column(
+        ForeignKey("llm_provider.id"), primary_key=True
+    )
+    user_group_id: Mapped[int] = mapped_column(
+        ForeignKey("user_group.id"), primary_key=True
+    )
+
+
 class DocumentSet__UserGroup(Base):
     __tablename__ = "document_set__user_group"
@@ -153,41 +153,43 @@ schema DANSWER_CHUNK_NAME {
            query(query_embedding) tensor<float>(x[VARIABLE_DIM])
        }

-        function title_vector_score() {
+        # This must be separate function for normalize_linear to work
+        function vector_score() {
            expression {
                # If no title, the full vector score comes from the content embedding
-                #query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
-                if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
+                (query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))) +
+                ((1 - query(title_content_ratio)) * closeness(field, embeddings))
            }
        }

+        # This must be separate function for normalize_linear to work
+        function keyword_score() {
+            expression {
+                (query(title_content_ratio) * bm25(title)) +
+                ((1 - query(title_content_ratio)) * bm25(content))
+            }
+        }
+
        first-phase {
-            expression: closeness(field, embeddings)
+            expression: vector_score
        }

        # Weighted average between Vector Search and BM-25
        # Each is a weighted average between the Title and Content fields
        # Finally each doc is boosted by it's user feedback based boost and recency
        # If any embedding or index field is missing, it just receives a score of 0
+        # Assumptions:
+        # - For a given query + corpus, the BM-25 scores will be relatively similar in distribution
+        #   therefore not normalizing before combining.
+        # - For documents without title, it gets a score of 0 for that and this is ok as documents
+        #   without any title match should be penalized.
        global-phase {
            expression {
                (
                    # Weighted Vector Similarity Score
-                    (
-                        query(alpha) * (
-                            (query(title_content_ratio) * normalize_linear(title_vector_score))
-                            +
-                            ((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings)))
-                        )
-                    )
-
-                    +
-
+                    (query(alpha) * normalize_linear(vector_score)) +
                    # Weighted Keyword Similarity Score
-                    (
-                        (1 - query(alpha)) * (
-                            (query(title_content_ratio) * normalize_linear(bm25(title)))
-                            +
-                            ((1 - query(title_content_ratio)) * normalize_linear(bm25(content)))
-                        )
-                    )
+                    ((1 - query(alpha)) * normalize_linear(keyword_score))
                )
                # Boost based on user feedback
                * document_boost
@@ -202,6 +204,8 @@ schema DANSWER_CHUNK_NAME {
            bm25(content)
            closeness(field, title_embedding)
            closeness(field, embeddings)
+            keyword_score
+            vector_score
            document_boost
            recency_bias
            closest(embeddings)
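The reshaped rank profile is easier to read as two weighted averages combined by `alpha`. A simplified Python sketch of the structure (Vespa's `normalize_linear` is replaced by an identity here, so the numbers only illustrate the weighting, not real scores):

```python
def vector_score(title_content_ratio, title_sim, content_sim, has_title=True):
    title_part = title_sim if has_title else content_sim  # the skip_title fallback
    return title_content_ratio * title_part + (1 - title_content_ratio) * content_sim


def keyword_score(title_content_ratio, bm25_title, bm25_content):
    return title_content_ratio * bm25_title + (1 - title_content_ratio) * bm25_content


def hybrid_score(alpha, ratio, title_sim, content_sim, bm25_title, bm25_content):
    return alpha * vector_score(ratio, title_sim, content_sim) + (
        1 - alpha
    ) * keyword_score(ratio, bm25_title, bm25_content)


print(hybrid_score(alpha=0.6, ratio=0.1, title_sim=0.82, content_sim=0.74,
                   bm25_title=3.1, bm25_content=7.8))
```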
@@ -331,12 +331,18 @@ def _index_vespa_chunk(
     document = chunk.source_document
     # No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
     vespa_chunk_id = str(get_uuid_from_chunk(chunk))

     embeddings = chunk.embeddings
+
+    if chunk.embeddings.full_embedding is None:
+        embeddings.full_embedding = chunk.title_embedding
     embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}

     if embeddings.mini_chunk_embeddings:
         for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
-            embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
+            if m_c_embed is None:
+                embeddings_name_vector_map[f"mini_chunk_{ind}"] = chunk.title_embedding
+            else:
+                embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed

     title = document.get_title_for_document_index()
@@ -346,11 +352,15 @@ def _index_vespa_chunk(
         BLURB: remove_invalid_unicode_chars(chunk.blurb),
         TITLE: remove_invalid_unicode_chars(title) if title else None,
         SKIP_TITLE_EMBEDDING: not title,
-        CONTENT: remove_invalid_unicode_chars(chunk.content),
+        # For the BM25 index, the keyword suffix is used, the vector is already generated with the more
+        # natural language representation of the metadata section
+        CONTENT: remove_invalid_unicode_chars(
+            f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
+        ),
+        # This duplication of `content` is needed for keyword highlighting
+        # Note that it's not exactly the same as the actual content
+        # which contains the title prefix and metadata suffix
-        CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content_summary),
+        CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content),
         SOURCE_TYPE: str(document.source.value),
         SOURCE_LINKS: json.dumps(chunk.source_links),
         SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
@@ -358,7 +368,7 @@ def _index_vespa_chunk(
         METADATA: json.dumps(document.metadata),
         # Save as a list for efficient extraction as an Attribute
         METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
-        METADATA_SUFFIX: chunk.metadata_suffix,
+        METADATA_SUFFIX: chunk.metadata_suffix_keyword,
         EMBEDDINGS: embeddings_name_vector_map,
         TITLE_EMBEDDING: chunk.title_embedding,
         BOOST: chunk.boost,
@@ -103,6 +103,7 @@ def port_api_key_to_postgres() -> None:
         default_model_name=default_model_name,
         fast_default_model_name=default_fast_model_name,
         model_names=None,
+        is_public=True,
     )
     llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
     update_default_provider(db_session, llm_provider.id)
@@ -6,7 +6,6 @@ from danswer.configs.app_configs import BLURB_SIZE
 from danswer.configs.app_configs import MINI_CHUNK_SIZE
 from danswer.configs.app_configs import SKIP_METADATA_IN_CHUNK
 from danswer.configs.constants import DocumentSource
-from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
 from danswer.configs.constants import RETURN_SEPARATOR
 from danswer.configs.constants import SECTION_SEPARATOR
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
@@ -15,12 +14,12 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
 )
 from danswer.connectors.models import Document
 from danswer.indexing.models import DocAwareChunk
-from danswer.search.search_nlp_models import get_default_tokenizer
+from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import shared_precompare_cleanup

 if TYPE_CHECKING:
-    from transformers import AutoTokenizer  # type:ignore
+    from llama_index.text_splitter import SentenceSplitter  # type:ignore


 # Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
@@ -28,6 +27,8 @@ if TYPE_CHECKING:
 CHUNK_OVERLAP = 0
 # Fairly arbitrary numbers but the general concept is we don't want the title/metadata to
 # overwhelm the actual contents of the chunk
+# For example in a rare case, this could be 128 tokens for the 512 chunk and title prefix
+# could be another 128 tokens leaving 256 for the actual contents
 MAX_METADATA_PERCENTAGE = 0.25
 CHUNK_MIN_CONTENT = 256
@@ -36,14 +37,7 @@ logger = setup_logger()
 ChunkFunc = Callable[[Document], list[DocAwareChunk]]


-def extract_blurb(text: str, blurb_size: int) -> str:
-    from llama_index.text_splitter import SentenceSplitter
-
-    token_count_func = get_default_tokenizer().tokenize
-    blurb_splitter = SentenceSplitter(
-        tokenizer=token_count_func, chunk_size=blurb_size, chunk_overlap=0
-    )
-
+def extract_blurb(text: str, blurb_splitter: "SentenceSplitter") -> str:
     return blurb_splitter.split_text(text)[0]
@@ -52,33 +46,25 @@ def chunk_large_section(
     section_link_text: str,
     document: Document,
     start_chunk_id: int,
-    tokenizer: "AutoTokenizer",
-    chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
-    chunk_overlap: int = CHUNK_OVERLAP,
-    blurb_size: int = BLURB_SIZE,
+    blurb: str,
+    chunk_splitter: "SentenceSplitter",
     title_prefix: str = "",
-    metadata_suffix: str = "",
+    metadata_suffix_semantic: str = "",
+    metadata_suffix_keyword: str = "",
 ) -> list[DocAwareChunk]:
-    from llama_index.text_splitter import SentenceSplitter
-
-    blurb = extract_blurb(section_text, blurb_size)
-
-    sentence_aware_splitter = SentenceSplitter(
-        tokenizer=tokenizer.tokenize, chunk_size=chunk_size, chunk_overlap=chunk_overlap
-    )
-
-    split_texts = sentence_aware_splitter.split_text(section_text)
+    split_texts = chunk_splitter.split_text(section_text)

     chunks = [
         DocAwareChunk(
             source_document=document,
             chunk_id=start_chunk_id + chunk_ind,
             blurb=blurb,
-            content=f"{title_prefix}{chunk_str}{metadata_suffix}",
-            content_summary=chunk_str,
+            content=chunk_str,
             source_links={0: section_link_text},
             section_continuation=(chunk_ind != 0),
-            metadata_suffix=metadata_suffix,
+            title_prefix=title_prefix,
+            metadata_suffix_semantic=metadata_suffix_semantic,
+            metadata_suffix_keyword=metadata_suffix_keyword,
         )
         for chunk_ind, chunk_str in enumerate(split_texts)
     ]
@@ -86,42 +72,87 @@ def chunk_large_section(


 def _get_metadata_suffix_for_document_index(
-    metadata: dict[str, str | list[str]]
-) -> str:
+    metadata: dict[str, str | list[str]], include_separator: bool = False
+) -> tuple[str, str]:
+    """
+    Returns the metadata as a natural language string representation with all of the keys and values for the vector embedding
+    and a string of all of the values for the keyword search
+
+    For example, if we have the following metadata:
+    {
+        "author": "John Doe",
+        "space": "Engineering"
+    }
+    The vector embedding string should include the relation between the key and value wheres as for keyword we only want John Doe
+    and Engineering. The keys are repeat and much more noisy.
+    """
     if not metadata:
-        return ""
+        return "", ""
+
     metadata_str = "Metadata:\n"
+    metadata_values = []
     for key, value in metadata.items():
+        if key in get_metadata_keys_to_ignore():
+            continue
+
         value_str = ", ".join(value) if isinstance(value, list) else value
+
+        if isinstance(value, list):
+            metadata_values.extend(value)
+        else:
+            metadata_values.append(value)
+
         metadata_str += f"\t{key} - {value_str}\n"
-    return metadata_str.strip()
+
+    metadata_semantic = metadata_str.strip()
+    metadata_keyword = " ".join(metadata_values)
+
+    if include_separator:
+        return RETURN_SEPARATOR + metadata_semantic, RETURN_SEPARATOR + metadata_keyword
+    return metadata_semantic, metadata_keyword


 def chunk_document(
     document: Document,
     chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     subsection_overlap: int = CHUNK_OVERLAP,
-    blurb_size: int = BLURB_SIZE,
+    blurb_size: int = BLURB_SIZE,  # Used for both title and content
     include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
 ) -> list[DocAwareChunk]:
+    from llama_index.text_splitter import SentenceSplitter
+
     tokenizer = get_default_tokenizer()

-    title = document.get_title_for_document_index()
-    title_prefix = f"{title[:MAX_CHUNK_TITLE_LEN]}{RETURN_SEPARATOR}" if title else ""
+    blurb_splitter = SentenceSplitter(
+        tokenizer=tokenizer.tokenize, chunk_size=blurb_size, chunk_overlap=0
+    )
+
+    chunk_splitter = SentenceSplitter(
+        tokenizer=tokenizer.tokenize,
+        chunk_size=chunk_tok_size,
+        chunk_overlap=subsection_overlap,
+    )
+
+    title = extract_blurb(document.get_title_for_document_index() or "", blurb_splitter)
+    title_prefix = title + RETURN_SEPARATOR if title else ""
     title_tokens = len(tokenizer.tokenize(title_prefix))

-    metadata_suffix = ""
+    metadata_suffix_semantic = ""
+    metadata_suffix_keyword = ""
     metadata_tokens = 0
     if include_metadata:
-        metadata = _get_metadata_suffix_for_document_index(document.metadata)
-        metadata_suffix = RETURN_SEPARATOR + metadata if metadata else ""
-        metadata_tokens = len(tokenizer.tokenize(metadata_suffix))
+        (
+            metadata_suffix_semantic,
+            metadata_suffix_keyword,
+        ) = _get_metadata_suffix_for_document_index(
+            document.metadata, include_separator=True
+        )
+        metadata_tokens = len(tokenizer.tokenize(metadata_suffix_semantic))

     if metadata_tokens >= chunk_tok_size * MAX_METADATA_PERCENTAGE:
-        metadata_suffix = ""
+        # Note: we can keep the keyword suffix even if the semantic suffix is too long to fit in the model
+        # context, there is no limit for the keyword component
+        metadata_suffix_semantic = ""
         metadata_tokens = 0

     content_token_limit = chunk_tok_size - title_tokens - metadata_tokens
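For the docstring's example metadata, the split the function now produces looks roughly like this (simplified sketch: no ignored keys, no leading separators):

```python
metadata = {"author": "John Doe", "space": "Engineering"}

metadata_suffix_semantic = "Metadata:\n" + "".join(
    f"\t{key} - {value}\n" for key, value in metadata.items()
)
metadata_suffix_keyword = " ".join(metadata.values())

print(metadata_suffix_semantic.strip())
# Metadata:
# 	author - John Doe
# 	space - Engineering
print(metadata_suffix_keyword)
# John Doe Engineering
```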
@@ -130,7 +161,7 @@ def chunk_document(
     if content_token_limit <= CHUNK_MIN_CONTENT:
         content_token_limit = chunk_tok_size
         title_prefix = ""
-        metadata_suffix = ""
+        metadata_suffix_semantic = ""

     chunks: list[DocAwareChunk] = []
     link_offsets: dict[int, str] = {}
@@ -151,12 +182,13 @@ def chunk_document(
                 DocAwareChunk(
                     source_document=document,
                     chunk_id=len(chunks),
-                    blurb=extract_blurb(chunk_text, blurb_size),
-                    content=f"{title_prefix}{chunk_text}{metadata_suffix}",
-                    content_summary=chunk_text,
+                    blurb=extract_blurb(chunk_text, blurb_splitter),
+                    content=chunk_text,
                     source_links=link_offsets,
                     section_continuation=False,
-                    metadata_suffix=metadata_suffix,
+                    title_prefix=title_prefix,
+                    metadata_suffix_semantic=metadata_suffix_semantic,
+                    metadata_suffix_keyword=metadata_suffix_keyword,
                 )
             )
             link_offsets = {}
@@ -167,12 +199,11 @@ def chunk_document(
                 section_link_text=section_link_text,
                 document=document,
                 start_chunk_id=len(chunks),
-                tokenizer=tokenizer,
-                chunk_size=content_token_limit,
-                chunk_overlap=subsection_overlap,
-                blurb_size=blurb_size,
+                chunk_splitter=chunk_splitter,
+                blurb=extract_blurb(section_text, blurb_splitter),
                 title_prefix=title_prefix,
-                metadata_suffix=metadata_suffix,
+                metadata_suffix_semantic=metadata_suffix_semantic,
+                metadata_suffix_keyword=metadata_suffix_keyword,
             )
             chunks.extend(large_section_chunks)
             continue
@@ -193,12 +224,13 @@ def chunk_document(
                 DocAwareChunk(
                     source_document=document,
                     chunk_id=len(chunks),
-                    blurb=extract_blurb(chunk_text, blurb_size),
-                    content=f"{title_prefix}{chunk_text}{metadata_suffix}",
-                    content_summary=chunk_text,
+                    blurb=extract_blurb(chunk_text, blurb_splitter),
+                    content=chunk_text,
                     source_links=link_offsets,
                     section_continuation=False,
-                    metadata_suffix=metadata_suffix,
+                    title_prefix=title_prefix,
+                    metadata_suffix_semantic=metadata_suffix_semantic,
+                    metadata_suffix_keyword=metadata_suffix_keyword,
                 )
             )
             link_offsets = {0: section_link_text}
@@ -211,12 +243,13 @@ def chunk_document(
             DocAwareChunk(
                 source_document=document,
                 chunk_id=len(chunks),
-                blurb=extract_blurb(chunk_text, blurb_size),
-                content=f"{title_prefix}{chunk_text}{metadata_suffix}",
-                content_summary=chunk_text,
+                blurb=extract_blurb(chunk_text, blurb_splitter),
+                content=chunk_text,
                 source_links=link_offsets,
                 section_continuation=False,
-                metadata_suffix=metadata_suffix,
+                title_prefix=title_prefix,
+                metadata_suffix_semantic=metadata_suffix_semantic,
+                metadata_suffix_keyword=metadata_suffix_keyword,
             )
         )
     return chunks
@@ -4,7 +4,6 @@ from abc import abstractmethod
 from sqlalchemy.orm import Session

 from danswer.configs.app_configs import ENABLE_MINI_CHUNK
-from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -14,8 +13,7 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
-from danswer.search.search_nlp_models import EmbeddingModel
-from danswer.utils.batching import batch_list
+from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
 from danswer.utils.logger import setup_logger
 from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
 from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
@@ -66,48 +64,38 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             # The below are globally set, this flow always uses the indexing one
             server_host=INDEXING_MODEL_SERVER_HOST,
             server_port=INDEXING_MODEL_SERVER_PORT,
+            retrim_content=True,
         )

     def embed_chunks(
         self,
         chunks: list[DocAwareChunk],
-        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
         enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
     ) -> list[IndexChunk]:
         # Cache the Title embeddings to only have to do it once
-        title_embed_dict: dict[str, list[float]] = {}
+        title_embed_dict: dict[str, list[float] | None] = {}
         embedded_chunks: list[IndexChunk] = []

         # Create Mini Chunks for more precise matching of details
         # Off by default with unedited settings
-        chunk_texts = []
+        chunk_texts: list[str] = []
         chunk_mini_chunks_count = {}
         for chunk_ind, chunk in enumerate(chunks):
-            chunk_texts.append(chunk.content)
+            # The whole chunk including the prefix/suffix is included in the overall vector representation
+            chunk_texts.append(
+                f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
+            )
             mini_chunk_texts = (
-                split_chunk_text_into_mini_chunks(chunk.content_summary)
+                split_chunk_text_into_mini_chunks(chunk.content)
                 if enable_mini_chunk
                 else []
             )
             chunk_texts.extend(mini_chunk_texts)
             chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)

-        # Batching for embedding
-        text_batches = batch_list(chunk_texts, batch_size)
-
-        embeddings: list[list[float]] = []
-        len_text_batches = len(text_batches)
-        for idx, text_batch in enumerate(text_batches, start=1):
-            logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
-            # Normalize embeddings is only configured via model_configs.py, be sure to use right
-            # value for the set loss
-            embeddings.extend(
-                self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
-            )
-
-            # Replace line above with the line below for easy debugging of indexing flow
-            # skipping the actual model
-            # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
+        embeddings = self.embedding_model.encode(
+            chunk_texts, text_type=EmbedTextType.PASSAGE
+        )

         chunk_titles = {
             chunk.source_document.get_title_for_document_index() for chunk in chunks
@@ -116,16 +104,15 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         # Drop any None or empty strings
         chunk_titles_list = [title for title in chunk_titles if title]

-        # Embed Titles in batches
-        title_batches = batch_list(chunk_titles_list, batch_size)
-        len_title_batches = len(title_batches)
-        for ind_batch, title_batch in enumerate(title_batches, start=1):
-            logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
+        if chunk_titles_list:
             title_embeddings = self.embedding_model.encode(
-                title_batch, text_type=EmbedTextType.PASSAGE
+                chunk_titles_list, text_type=EmbedTextType.PASSAGE
             )
             title_embed_dict.update(
-                {title: vector for title, vector in zip(title_batch, title_embeddings)}
+                {
+                    title: vector
+                    for title, vector in zip(chunk_titles_list, title_embeddings)
+                }
             )

         # Mapping embeddings to chunks
@@ -184,4 +171,6 @@ def get_embedding_model_from_db_embedding_model(
         normalize=db_embedding_model.normalize,
         query_prefix=db_embedding_model.query_prefix,
         passage_prefix=db_embedding_model.passage_prefix,
+        provider_type=db_embedding_model.provider_type,
+        api_key=db_embedding_model.api_key,
     )
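Taken together with the Vespa change above, each chunk now has two derived text views: a semantic one (title prefix + content + key/value metadata), which is what gets embedded, and a keyword one (title prefix + content + metadata values only), which is what gets BM25-indexed. A small sketch with made-up values, using the field names from the diff:

```python
title_prefix = "Quarterly planning\n\r\n"
content = "The Q3 roadmap covers the new connector framework."
metadata_suffix_semantic = "\n\r\nMetadata:\n\tauthor - John Doe"
metadata_suffix_keyword = "\n\r\nJohn Doe"

text_to_embed = f"{title_prefix}{content}{metadata_suffix_semantic}"   # embedder path
text_for_bm25 = f"{title_prefix}{content}{metadata_suffix_keyword}"    # Vespa CONTENT field
```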
@@ -124,6 +124,19 @@ def index_doc_batch(
     """Takes different pieces of the indexing pipeline and applies it to a batch of documents
     Note that the documents should already be batched at this point so that it does not inflate the
     memory requirements"""
+    # Skip documents that have neither title nor content
+    documents_to_process = []
+    for document in documents:
+        if not document.title and not any(
+            section.text.strip() for section in document.sections
+        ):
+            logger.warning(
+                f"Skipping document with ID {document.id} as it has neither title nor content"
+            )
+        else:
+            documents_to_process.append(document)
+    documents = documents_to_process
+
     document_ids = [document.id for document in documents]
     db_docs = get_documents_by_ids(
         document_ids=document_ids,
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
Embedding = list[float]
|
||||
Embedding = list[float] | None
|
||||
|
||||
|
||||
class ChunkEmbedding(BaseModel):
|
||||
@@ -36,15 +36,13 @@ class DocAwareChunk(BaseChunk):
|
||||
# During inference we only have access to the document id and do not reconstruct the Document
|
||||
source_document: Document
|
||||
|
||||
# The Vespa documents require a separate highlight field. Since it is stored as a duplicate anyway,
|
||||
# it's easier to just store a non-prefixed/suffixed string for the highlighting
|
||||
# Also during the chunking, this non-prefixed/suffixed string is used for mini-chunks
|
||||
content_summary: str
|
||||
title_prefix: str
|
||||
|
||||
# During indexing we also (optionally) build a metadata string from the metadata dict
|
||||
# This is also indexed so that we can strip it out after indexing, this way it supports
|
||||
# multiple iterations of metadata representation for backwards compatibility
|
||||
metadata_suffix: str
|
||||
metadata_suffix_semantic: str
|
||||
metadata_suffix_keyword: str
|
||||
|
||||
def to_short_descriptor(self) -> str:
|
||||
"""Used when logging the identity of a chunk"""
|
||||
|
||||
@@ -34,8 +34,8 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import message_generator_to_string_generator
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.tools.custom.custom_tool_prompt_builder import (
|
||||
build_user_message_for_custom_tool_for_non_tool_calling_llm,
|
||||
)
|
||||
@@ -99,6 +99,7 @@ class Answer:
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
llm: LLM,
|
||||
prompt_config: PromptConfig,
|
||||
force_use_tool: ForceUseTool,
|
||||
# must be the same length as `docs`. If None, all docs are considered "relevant"
|
||||
message_history: list[PreviousMessage] | None = None,
|
||||
single_message_history: str | None = None,
|
||||
@@ -107,10 +108,8 @@ class Answer:
|
||||
latest_query_files: list[InMemoryChatFile] | None = None,
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
tools: list[Tool] | None = None,
|
||||
# if specified, tells the LLM to always use this tool
|
||||
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
|
||||
# but we only support them anyways
|
||||
force_use_tool: ForceUseTool | None = None,
|
||||
# if set to True, then never use the LLM's provided tool-calling functionality
|
||||
skip_explicit_tool_calling: bool = False,
|
||||
# Returns the full document sections text from the search tool
|
||||
@@ -129,6 +128,7 @@ class Answer:
|
||||
|
||||
self.tools = tools or []
|
||||
self.force_use_tool = force_use_tool
|
||||
|
||||
self.skip_explicit_tool_calling = skip_explicit_tool_calling
|
||||
|
||||
self.message_history = message_history or []
|
||||
@@ -187,7 +187,7 @@ class Answer:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
if self.force_use_tool and self.force_use_tool.args is not None:
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
|
||||
# / need to generate the args
|
||||
tool_call_chunk = AIMessageChunk(
|
||||
@@ -221,7 +221,7 @@ class Answer:
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
tool_choice="required" if self.force_use_tool else None,
|
||||
tool_choice="required" if self.force_use_tool.force_use else None,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
@@ -240,12 +240,26 @@ class Answer:
|
||||
# if we have a tool call, we need to call the tool
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
for tool_call_request in tool_call_requests:
|
||||
tool = [
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
][0]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool and self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
|
||||
@@ -286,7 +300,7 @@ class Answer:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
|
||||
if self.force_use_tool:
|
||||
if self.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
iter(
|
||||
@@ -303,7 +317,7 @@ class Answer:
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.args
|
||||
if self.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
|
||||
@@ -12,8 +12,8 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import check_message_tokens
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import translate_history_to_basemessages
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import drop_messages_history_overflow
|
||||
|
||||
@@ -14,8 +14,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import tokenizer_trim_content
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.prompts.prompt_utils import build_doc_context_str
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
|
||||
@@ -70,6 +70,7 @@ def load_llm_providers(db_session: Session) -> None:
|
||||
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
|
||||
),
|
||||
model_names=model_names,
|
||||
is_public=True,
|
||||
)
|
||||
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
|
||||
update_default_provider(db_session, llm_provider.id)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from copy import copy
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -16,10 +15,8 @@ from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from tiktoken.core import Encoding
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
|
||||
@@ -28,7 +25,6 @@ from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.prompts.constants import CODE_BLOCK_PAT
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import LOG_LEVEL
|
||||
|
||||
@@ -37,53 +33,6 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_LLM_TOKENIZER: Any = None
|
||||
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
|
||||
|
||||
|
||||
def get_default_llm_tokenizer() -> Encoding:
|
||||
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
|
||||
global _LLM_TOKENIZER
|
||||
if _LLM_TOKENIZER is None:
|
||||
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
|
||||
return _LLM_TOKENIZER
|
||||
|
||||
|
||||
def get_default_llm_token_encode() -> Callable[[str], Any]:
|
||||
global _LLM_TOKENIZER_ENCODE
|
||||
if _LLM_TOKENIZER_ENCODE is None:
|
||||
tokenizer = get_default_llm_tokenizer()
|
||||
if isinstance(tokenizer, Encoding):
|
||||
return tokenizer.encode # type: ignore
|
||||
|
||||
# Currently only supports OpenAI encoder
|
||||
raise ValueError("Invalid Encoder selected")
|
||||
|
||||
return _LLM_TOKENIZER_ENCODE
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: Encoding
|
||||
) -> str:
|
||||
tokens = tokenizer.encode(content)
|
||||
if len(tokens) > desired_length:
|
||||
content = tokenizer.decode(tokens[:desired_length])
|
||||
return content
|
||||
|
||||
|
||||
def tokenizer_trim_chunks(
|
||||
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
|
||||
) -> list[InferenceChunk]:
|
||||
tokenizer = get_default_llm_tokenizer()
|
||||
new_chunks = copy(chunks)
|
||||
for ind, chunk in enumerate(new_chunks):
|
||||
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
|
||||
if len(new_content) != len(chunk.content):
|
||||
new_chunk = copy(chunk)
|
||||
new_chunk.content = new_content
|
||||
new_chunks[ind] = new_chunk
|
||||
return new_chunks
|
||||
|
||||
|
||||
def translate_danswer_msg_to_langchain(
|
||||
msg: Union[ChatMessage, "PreviousMessage"],
|
||||
|
||||
@@ -50,8 +50,8 @@ from danswer.db.standard_answer import create_initial_default_standard_answer_ca
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.llm.llm_initialization import load_llm_providers
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.search.search_nlp_models import warm_up_encoders
|
||||
from danswer.server.auth_check import check_router_auth
|
||||
from danswer.server.danswer_api.ingestion import router as danswer_api_router
|
||||
from danswer.server.documents.cc_pair import router as cc_pair_router
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import requests
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
from httpx import HTTPError
|
||||
|
||||
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.natural_language_processing.utils import get_default_tokenizer
|
||||
from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.utils.batching import batch_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
@@ -20,50 +19,13 @@ from shared_configs.model_server_models import IntentResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
from shared_configs.model_server_models import RerankResponse
|
||||
|
||||
transformer_logging.set_verbosity_error()
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
|
||||
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
|
||||
|
||||
|
||||
def clean_model_name(model_str: str) -> str:
|
||||
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
|
||||
|
||||
|
||||
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer.
# For cases where this is more important, be sure to refresh with the actual model name.
|
||||
# One case where it is not particularly important is in the document chunking flow,
|
||||
# they're basically all using the sentencepiece tokenizer and whether it's cased or
|
||||
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
|
||||
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
|
||||
# NOTE: doing a local import here to reduce memory usage caused by
|
||||
# processes importing this file despite not using any of this
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
global _TOKENIZER
|
||||
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
|
||||
if _TOKENIZER[0] is not None:
|
||||
del _TOKENIZER
|
||||
gc.collect()
|
||||
|
||||
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
|
||||
|
||||
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
return _TOKENIZER[0]
|
||||
|
||||
|
||||
def build_model_server_url(
|
||||
model_server_host: str,
|
||||
model_server_port: int,
|
||||
@@ -91,6 +53,7 @@ class EmbeddingModel:
|
||||
provider_type: str | None,
|
||||
# The following globals are currently not configurable
|
||||
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
retrim_content: bool = False,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.provider_type = provider_type
|
||||
@@ -99,32 +62,90 @@ class EmbeddingModel:
|
||||
self.passage_prefix = passage_prefix
|
||||
self.normalize = normalize
|
||||
self.model_name = model_name
|
||||
self.retrim_content = retrim_content
|
||||
|
||||
model_server_url = build_model_server_url(server_host, server_port)
|
||||
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
||||
|
||||
def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
|
||||
if text_type == EmbedTextType.QUERY and self.query_prefix:
|
||||
prefixed_texts = [self.query_prefix + text for text in texts]
|
||||
elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
|
||||
prefixed_texts = [self.passage_prefix + text for text in texts]
|
||||
else:
|
||||
prefixed_texts = texts
|
||||
def encode(
|
||||
self,
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
|
||||
) -> list[list[float] | None]:
|
||||
if not texts:
|
||||
logger.warning("No texts to be embedded")
|
||||
return []
|
||||
|
||||
embed_request = EmbedRequest(
|
||||
model_name=self.model_name,
|
||||
texts=prefixed_texts,
|
||||
max_context_length=self.max_seq_length,
|
||||
normalize_embeddings=self.normalize,
|
||||
api_key=self.api_key,
|
||||
provider_type=self.provider_type,
|
||||
text_type=text_type,
|
||||
)
|
||||
if self.retrim_content:
|
||||
# This is applied during indexing as a catchall for overly long titles (or other uncapped fields)
|
||||
# Note that this uses just the default tokenizer which may also lead to very minor miscountings
|
||||
# However this slight miscounting is very unlikely to have any material impact.
|
||||
texts = [
|
||||
tokenizer_trim_content(
|
||||
content=text,
|
||||
desired_length=self.max_seq_length,
|
||||
tokenizer=get_default_tokenizer(),
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
|
||||
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
|
||||
response.raise_for_status()
|
||||
if self.provider_type:
|
||||
embed_request = EmbedRequest(
|
||||
model_name=self.model_name,
|
||||
texts=texts,
|
||||
max_context_length=self.max_seq_length,
|
||||
normalize_embeddings=self.normalize,
|
||||
api_key=self.api_key,
|
||||
provider_type=self.provider_type,
|
||||
text_type=text_type,
|
||||
manual_query_prefix=self.query_prefix,
|
||||
manual_passage_prefix=self.passage_prefix,
|
||||
)
|
||||
response = requests.post(
|
||||
self.embed_server_endpoint, json=embed_request.dict()
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
error_detail = response.json().get("detail", str(e))
|
||||
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
|
||||
except requests.RequestException as e:
|
||||
raise HTTPError(f"Request failed: {str(e)}") from e
|
||||
EmbedResponse(**response.json()).embeddings
|
||||
|
||||
return EmbedResponse(**response.json()).embeddings
|
||||
return EmbedResponse(**response.json()).embeddings
|
||||
|
||||
# Batching for local embedding
|
||||
text_batches = batch_list(texts, batch_size)
|
||||
embeddings: list[list[float] | None] = []
|
||||
for idx, text_batch in enumerate(text_batches, start=1):
|
||||
embed_request = EmbedRequest(
|
||||
model_name=self.model_name,
|
||||
texts=text_batch,
|
||||
max_context_length=self.max_seq_length,
|
||||
normalize_embeddings=self.normalize,
|
||||
api_key=self.api_key,
|
||||
provider_type=self.provider_type,
|
||||
text_type=text_type,
|
||||
manual_query_prefix=self.query_prefix,
|
||||
manual_passage_prefix=self.passage_prefix,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
self.embed_server_endpoint, json=embed_request.dict()
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
error_detail = response.json().get("detail", str(e))
|
||||
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
|
||||
except requests.RequestException as e:
|
||||
raise HTTPError(f"Request failed: {str(e)}") from e
|
||||
# Normalize embeddings is only configured via model_configs.py, be sure to use the right
# value for the set loss
|
||||
embeddings.extend(EmbedResponse(**response.json()).embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CrossEncoderEnsembleModel:
|
||||
@@ -136,7 +157,7 @@ class CrossEncoderEnsembleModel:
|
||||
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
||||
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
|
||||
|
||||
def predict(self, query: str, passages: list[str]) -> list[list[float]]:
|
||||
def predict(self, query: str, passages: list[str]) -> list[list[float] | None]:
|
||||
rerank_request = RerankRequest(query=query, documents=passages)
|
||||
|
||||
response = requests.post(
|
||||
@@ -199,7 +220,7 @@ def warm_up_encoders(
|
||||
# First time downloading the models it may take even longer, but just in case,
|
||||
# retry the whole server
|
||||
wait_time = 5
|
||||
for attempt in range(20):
|
||||
for _ in range(20):
|
||||
try:
|
||||
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
||||
return
|
||||
100
backend/danswer/natural_language_processing/utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import gc
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from copy import copy
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import tiktoken
|
||||
from tiktoken.core import Encoding
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
transformer_logging.set_verbosity_error()
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
|
||||
|
||||
|
||||
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
|
||||
_LLM_TOKENIZER: Any = None
|
||||
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
|
||||
|
||||
|
||||
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer.
# For cases where this is more important, be sure to refresh with the actual model name.
|
||||
# One case where it is not particularly important is in the document chunking flow,
|
||||
# they're basically all using the sentencepiece tokenizer and whether it's cased or
|
||||
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
|
||||
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
|
||||
# NOTE: doing a local import here to reduce memory usage caused by
|
||||
# processes importing this file despite not using any of this
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
global _TOKENIZER
|
||||
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
|
||||
if _TOKENIZER[0] is not None:
|
||||
del _TOKENIZER
|
||||
gc.collect()
|
||||
|
||||
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
|
||||
|
||||
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
return _TOKENIZER[0]
|
||||
|
||||
|
||||
def get_default_llm_tokenizer() -> Encoding:
|
||||
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
|
||||
global _LLM_TOKENIZER
|
||||
if _LLM_TOKENIZER is None:
|
||||
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
|
||||
return _LLM_TOKENIZER
|
||||
|
||||
|
||||
def get_default_llm_token_encode() -> Callable[[str], Any]:
|
||||
global _LLM_TOKENIZER_ENCODE
|
||||
if _LLM_TOKENIZER_ENCODE is None:
|
||||
tokenizer = get_default_llm_tokenizer()
|
||||
if isinstance(tokenizer, Encoding):
|
||||
return tokenizer.encode # type: ignore
|
||||
|
||||
# Currently only supports OpenAI encoder
|
||||
raise ValueError("Invalid Encoder selected")
|
||||
|
||||
return _LLM_TOKENIZER_ENCODE
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: Encoding
|
||||
) -> str:
|
||||
tokens = tokenizer.encode(content)
|
||||
if len(tokens) > desired_length:
|
||||
content = tokenizer.decode(tokens[:desired_length])
|
||||
return content
|
||||
|
||||
|
||||
def tokenizer_trim_chunks(
|
||||
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
|
||||
) -> list[InferenceChunk]:
|
||||
tokenizer = get_default_llm_tokenizer()
|
||||
new_chunks = copy(chunks)
|
||||
for ind, chunk in enumerate(new_chunks):
|
||||
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
|
||||
if len(new_content) != len(chunk.content):
|
||||
new_chunk = copy(chunk)
|
||||
new_chunk.content = new_content
|
||||
new_chunks[ind] = new_chunk
|
||||
return new_chunks
|
||||
@@ -34,7 +34,7 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.models import QuotesConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
from danswer.natural_language_processing.utils import get_default_llm_token_encode
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.one_shot_answer.models import QueryRephrase
|
||||
@@ -206,6 +206,7 @@ def stream_answer_objects(
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool],
|
||||
force_use_tool=ForceUseTool(
|
||||
force_use=True,
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
),
|
||||
@@ -256,6 +257,9 @@ def stream_answer_objects(
|
||||
)
|
||||
yield initial_response
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
yield packet.response
|
||||
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
chunk_indices = packet.response
|
||||
|
||||
@@ -268,9 +272,6 @@ def stream_answer_objects(
|
||||
|
||||
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
yield packet.response
|
||||
|
||||
elif packet.id == SEARCH_EVALUATION_ID:
|
||||
evaluation_response = LLMRelevanceSummaryResponse(
|
||||
relevance_summaries=packet.response
|
||||
|
||||
@@ -2,7 +2,7 @@ from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
from danswer.natural_language_processing.utils import get_default_llm_token_encode
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -154,6 +155,7 @@ class SearchPipeline:
|
||||
|
||||
return cast(list[InferenceChunk], self._retrieved_chunks)
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _get_sections(self) -> list[InferenceSection]:
|
||||
"""Returns an expanded section from each of the chunks.
|
||||
If whole docs (instead of above/below context) is specified then it will give back all of the whole docs
|
||||
@@ -173,9 +175,11 @@ class SearchPipeline:
|
||||
expanded_inference_sections = []
|
||||
|
||||
# Full doc setting takes priority
|
||||
|
||||
if self.search_query.full_doc:
|
||||
seen_document_ids = set()
|
||||
unique_chunks = []
|
||||
|
||||
# This preserves the ordering since the chunks are retrieved in score order
|
||||
for chunk in retrieved_chunks:
|
||||
if chunk.document_id not in seen_document_ids:
|
||||
@@ -195,7 +199,6 @@ class SearchPipeline:
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
list_inference_chunks = run_functions_tuples_in_parallel(
|
||||
functions_with_args, allow_failures=False
|
||||
)
|
||||
@@ -240,32 +243,35 @@ class SearchPipeline:
|
||||
merged_ranges = [
|
||||
merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
|
||||
]
|
||||
flat_ranges = [r for ranges in merged_ranges for r in ranges]
|
||||
|
||||
flat_ranges: list[ChunkRange] = [r for ranges in merged_ranges for r in ranges]
|
||||
flattened_inference_chunks: list[InferenceChunk] = []
|
||||
parallel_functions_with_args = []
|
||||
|
||||
for chunk_range in flat_ranges:
|
||||
functions_with_args.append(
|
||||
(
|
||||
# If Large Chunks are introduced, additional filters need to be added here
|
||||
self.document_index.id_based_retrieval,
|
||||
(
|
||||
# Only need the document_id here; using any chunk in the range is fine
|
||||
chunk_range.chunks[0].document_id,
|
||||
chunk_range.start,
|
||||
chunk_range.end,
|
||||
# There is no chunk level permissioning, this expansion around chunks
|
||||
# can be assumed to be safe
|
||||
IndexFilters(access_control_list=None),
|
||||
),
|
||||
)
|
||||
)
|
||||
# Don't need to fetch chunks within range for merging if chunk_above / below are 0.
|
||||
if above == below == 0:
|
||||
flattened_inference_chunks.extend(chunk_range.chunks)
|
||||
|
||||
# list of list of inference chunks where the inner list needs to be combined for content
|
||||
list_inference_chunks = run_functions_tuples_in_parallel(
|
||||
functions_with_args, allow_failures=False
|
||||
)
|
||||
flattened_inference_chunks = [
|
||||
chunk for sublist in list_inference_chunks for chunk in sublist
|
||||
]
|
||||
else:
|
||||
parallel_functions_with_args.append(
|
||||
(
|
||||
self.document_index.id_based_retrieval,
|
||||
(
|
||||
chunk_range.chunks[0].document_id,
|
||||
chunk_range.start,
|
||||
chunk_range.end,
|
||||
IndexFilters(access_control_list=None),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if parallel_functions_with_args:
|
||||
list_inference_chunks = run_functions_tuples_in_parallel(
|
||||
parallel_functions_with_args, allow_failures=False
|
||||
)
|
||||
for inference_chunks in list_inference_chunks:
|
||||
flattened_inference_chunks.extend(inference_chunks)
|
||||
|
||||
doc_chunk_ind_to_chunk = {
|
||||
(chunk.document_id, chunk.chunk_id): chunk
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import cast
|
||||
|
||||
import numpy
|
||||
|
||||
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
|
||||
from danswer.configs.app_configs import BLURB_SIZE
|
||||
from danswer.configs.constants import RETURN_SEPARATOR
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
|
||||
@@ -12,6 +12,9 @@ from danswer.document_index.document_index_utils import (
|
||||
translate_boost_count_to_multiplier,
|
||||
)
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.natural_language_processing.search_nlp_models import (
|
||||
CrossEncoderEnsembleModel,
|
||||
)
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
@@ -20,7 +23,6 @@ from danswer.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
|
||||
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
@@ -58,8 +60,14 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
|
||||
if chunk.content.startswith(chunk.title):
|
||||
return chunk.content[len(chunk.title) :].lstrip()
|
||||
|
||||
if chunk.content.startswith(chunk.title[:MAX_CHUNK_TITLE_LEN]):
|
||||
return chunk.content[MAX_CHUNK_TITLE_LEN:].lstrip()
|
||||
# BLURB SIZE is by token instead of char but each token is at least 1 char
|
||||
# If this prefix matches the content, it's assumed the title was prepended
|
||||
if chunk.content.startswith(chunk.title[:BLURB_SIZE]):
|
||||
return (
|
||||
chunk.content.split(RETURN_SEPARATOR, 1)[-1]
|
||||
if RETURN_SEPARATOR in chunk.content
|
||||
else chunk.content
|
||||
)
|
||||
|
||||
return chunk.content
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
|
||||
from danswer.natural_language_processing.search_nlp_models import IntentModel
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
|
||||
from danswer.search.search_nlp_models import get_default_tokenizer
|
||||
from danswer.search.search_nlp_models import IntentModel
|
||||
from danswer.server.query_and_chat.models import HelperResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import string
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
import nltk # type:ignore
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
@@ -11,6 +12,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunk
|
||||
@@ -20,7 +22,6 @@ from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.postprocessing.postprocessing import cleanup_chunks
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.search.utils import inference_section_from_chunks
|
||||
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -143,7 +144,9 @@ def doc_index_retrieval(
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
top_chunks = document_index.semantic_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
query_embedding=cast(
|
||||
list[float], query_embedding
|
||||
), # query embeddings should always have vector representations
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
@@ -152,7 +155,9 @@ def doc_index_retrieval(
|
||||
elif query.search_type == SearchType.HYBRID:
|
||||
top_chunks = document_index.hybrid_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
query_embedding=cast(
|
||||
list[float], query_embedding
|
||||
), # query embeddings should always have vector representations
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
|
||||
@@ -94,7 +94,7 @@ def history_based_query_rephrase(
|
||||
llm: LLM,
|
||||
size_heuristic: int = 200,
|
||||
punctuation_heuristic: int = 10,
|
||||
skip_first_rephrase: bool = False,
|
||||
skip_first_rephrase: bool = True,
|
||||
prompt_template: str = HISTORY_QUERY_REPHRASE,
|
||||
) -> str:
|
||||
# Globally disabled, just use the exact user query
|
||||
|
||||
@@ -96,6 +96,8 @@ def upsert_ingestion_doc(
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
api_key=db_embedding_model.api_key,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
@@ -132,6 +134,8 @@ def upsert_ingestion_doc(
|
||||
normalize=sec_db_embedding_model.normalize,
|
||||
query_prefix=sec_db_embedding_model.query_prefix,
|
||||
passage_prefix=sec_db_embedding_model.passage_prefix,
|
||||
api_key=sec_db_embedding_model.api_key,
|
||||
provider_type=sec_db_embedding_model.provider_type,
|
||||
)
|
||||
|
||||
sec_ind_pipeline = build_indexing_pipeline(
|
||||
|
||||
@@ -9,7 +9,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
from danswer.natural_language_processing.utils import get_default_llm_token_encode
|
||||
from danswer.prompts.prompt_utils import build_doc_context_str
|
||||
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
|
||||
from danswer.server.documents.models import ChunkInfo
|
||||
|
||||
@@ -10,7 +10,7 @@ from danswer.db.llm import fetch_existing_embedding_providers
|
||||
from danswer.db.llm import remove_embedding_provider
|
||||
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||
from danswer.db.models import User
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.embedding.models import TestEmbeddingRequest
|
||||
@@ -42,7 +42,7 @@ def test_embedding_configuration(
|
||||
passage_prefix=None,
|
||||
model_name=None,
|
||||
)
|
||||
test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
|
||||
test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = f"Not a valid embedding model. Exception thrown: {e}"
|
||||
|
||||
@@ -147,10 +147,10 @@ def set_provider_as_default(
|
||||
|
||||
@basic_router.get("/provider")
|
||||
def list_llm_provider_basics(
|
||||
_: User | None = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
return [
|
||||
LLMProviderDescriptor.from_model(llm_provider_model)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
|
||||
]
|
||||
|
||||
@@ -60,6 +60,8 @@ class LLMProvider(BaseModel):
|
||||
custom_config: dict[str, str] | None
|
||||
default_model_name: str
|
||||
fast_default_model_name: str | None
|
||||
is_public: bool = True
|
||||
groups: list[int] | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
@@ -91,4 +93,6 @@ class FullLLMProvider(LLMProvider):
|
||||
or fetch_models_for_provider(llm_provider_model.provider)
|
||||
or [llm_provider_model.default_model_name]
|
||||
),
|
||||
is_public=llm_provider_model.is_public,
|
||||
groups=[group.id for group in llm_provider_model.groups],
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -14,13 +15,15 @@ from danswer.db.models import SlackBotConfig as SlackBotConfigModel
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
|
||||
from danswer.db.models import User
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import User as UserModel
|
||||
pass
|
||||
|
||||
|
||||
class VersionResponse(BaseModel):
|
||||
@@ -46,9 +49,17 @@ class UserInfo(BaseModel):
|
||||
is_verified: bool
|
||||
role: UserRole
|
||||
preferences: UserPreferences
|
||||
oidc_expiry: datetime | None = None
|
||||
current_token_created_at: datetime | None = None
|
||||
current_token_expiry_length: int | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user: "UserModel") -> "UserInfo":
|
||||
def from_model(
|
||||
cls,
|
||||
user: User,
|
||||
current_token_created_at: datetime | None = None,
|
||||
expiry_length: int | None = None,
|
||||
) -> "UserInfo":
|
||||
return cls(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
@@ -57,6 +68,9 @@ class UserInfo(BaseModel):
|
||||
is_verified=user.is_verified,
|
||||
role=user.role,
|
||||
preferences=(UserPreferences(chosen_assistants=user.chosen_assistants)),
|
||||
oidc_expiry=user.oidc_expiry,
|
||||
current_token_created_at=current_token_created_at,
|
||||
current_token_expiry_length=expiry_length,
|
||||
)
|
||||
|
||||
|
||||
@@ -151,7 +165,9 @@ class SlackBotConfigCreationRequest(BaseModel):
|
||||
# by an optional `PersonaSnapshot` object. Keeping it like this
|
||||
# for now for simplicity / speed of development
|
||||
document_sets: list[int] | None
|
||||
persona_id: int | None # NOTE: only one of `document_sets` / `persona_id` should be set
|
||||
persona_id: (
|
||||
int | None
|
||||
) # NOTE: only one of `document_sets` / `persona_id` should be set
|
||||
channel_names: list[str]
|
||||
respond_tag_only: bool = False
|
||||
respond_to_bots: bool = False
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Body
|
||||
@@ -6,6 +7,9 @@ from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -19,9 +23,11 @@ from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import optional_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.db.users import list_users
|
||||
@@ -117,9 +123,9 @@ def list_all_users(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
status=UserStatus.LIVE
|
||||
if user.is_active
|
||||
else UserStatus.DEACTIVATED,
|
||||
status=(
|
||||
UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
|
||||
),
|
||||
)
|
||||
for user in users
|
||||
],
|
||||
@@ -246,9 +252,35 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
|
||||
return UserRoleResponse(role=user.role)
|
||||
|
||||
|
||||
def get_current_token_creation(
|
||||
user: User | None, db_session: Session
|
||||
) -> datetime | None:
|
||||
if user is None:
|
||||
return None
|
||||
try:
|
||||
result = db_session.execute(
|
||||
select(AccessToken)
|
||||
.where(AccessToken.user_id == user.id) # type: ignore
|
||||
.order_by(desc(Column("created_at")))
|
||||
.limit(1)
|
||||
)
|
||||
access_token = result.scalar_one_or_none()
|
||||
|
||||
if access_token:
|
||||
return access_token.created_at
|
||||
else:
|
||||
logger.error("No AccessToken found for user")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching AccessToken: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
def verify_user_logged_in(
|
||||
user: User | None = Depends(optional_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserInfo:
|
||||
# NOTE: this does not use `current_user` / `current_admin_user` because we don't want
|
||||
# to enforce user verification here - the frontend always wants to get the info about
|
||||
@@ -264,7 +296,14 @@ def verify_user_logged_in(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="User Not Authenticated"
|
||||
)
|
||||
|
||||
return UserInfo.from_model(user)
|
||||
token_created_at = get_current_token_creation(user, db_session)
|
||||
user_info = UserInfo.from_model(
|
||||
user,
|
||||
current_token_created_at=token_created_at,
|
||||
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
|
||||
return user_info
|
||||
|
||||
|
||||
"""APIs to adjust user preferences"""
|
||||
|
||||
@@ -45,7 +45,7 @@ from danswer.llm.answering.prompts.citations_prompt import (
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.headers import get_litellm_additional_request_headers
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.secondary_llm_flows.chat_session_naming import (
|
||||
get_renamed_conversation_name,
|
||||
)
|
||||
|
||||
@@ -90,7 +90,7 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
parent_message_id: int | None
|
||||
# New message contents
|
||||
message: str
|
||||
# file's that we should attach to this message
|
||||
# Files that we should attach to this message
|
||||
file_descriptors: list[FileDescriptor]
|
||||
# If no prompt provided, uses the largest prompt of the chat session
|
||||
# but really this should be explicitly specified, only in the simplified APIs is this inferred
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
class ForceUseTool(BaseModel):
|
||||
# Could be not a forced usage of the tool but still have args, in which case
|
||||
# if the tool is called, then those args are applied instead of what the LLM
|
||||
# wanted to call it with
|
||||
force_use: bool
|
||||
tool_name: str
|
||||
args: dict[str, Any] | None = None
|
||||
|
||||
@@ -16,25 +18,10 @@ class ForceUseTool(BaseModel):
|
||||
return {"type": "function", "function": {"name": self.tool_name}}
|
||||
|
||||
|
||||
def modify_message_chain_for_force_use_tool(
|
||||
messages: list[BaseMessage], force_use_tool: ForceUseTool | None = None
|
||||
) -> list[BaseMessage]:
|
||||
"""NOTE: modifies `messages` in place."""
|
||||
if not force_use_tool:
|
||||
return messages
|
||||
|
||||
for message in messages:
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_call["args"] = force_use_tool.args or {}
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def filter_tools_for_force_tool_use(
|
||||
tools: list[Tool], force_use_tool: ForceUseTool | None = None
|
||||
tools: list[Tool], force_use_tool: ForceUseTool
|
||||
) -> list[Tool]:
|
||||
if not force_use_tool:
|
||||
if not force_use_tool.force_use:
|
||||
return tools
|
||||
|
||||
return [tool for tool in tools if tool.name == force_use_tool.tool_name]
|
||||
|
||||
@@ -6,7 +6,7 @@ from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.messages.tool import ToolMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
|
||||
|
||||
def build_tool_message(
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
|
||||
from tiktoken import Encoding
|
||||
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import TokenRateLimit__UserGroup
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
@@ -194,6 +195,15 @@ def _cleanup_user__user_group_relationships__no_commit(
|
||||
db_session.delete(user__user_group_relationship)
|
||||
|
||||
|
||||
def _cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
db_session.query(LLMProvider__UserGroup).filter(
|
||||
LLMProvider__UserGroup.user_group_id == user_group_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
@@ -316,6 +326,9 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
|
||||
|
||||
|
||||
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
|
||||
_cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id
|
||||
)
|
||||
_cleanup_user__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from cohere import Client as CohereClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from retry import retry
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||
@@ -40,110 +41,133 @@ router = APIRouter(prefix="/encoder")
|
||||
|
||||
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
||||
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
|
||||
# If we are not only indexing, don't want to retry for very long
|
||||
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
|
||||
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
|
||||
|
||||
|
||||
def _initialize_client(
|
||||
api_key: str, provider: EmbeddingProvider, model: str | None = None
|
||||
) -> Any:
|
||||
if provider == EmbeddingProvider.OPENAI:
|
||||
return openai.OpenAI(api_key=api_key)
|
||||
elif provider == EmbeddingProvider.COHERE:
|
||||
return CohereClient(api_key=api_key)
|
||||
elif provider == EmbeddingProvider.VOYAGE:
|
||||
return voyageai.Client(api_key=api_key)
|
||||
elif provider == EmbeddingProvider.GOOGLE:
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(api_key)
|
||||
)
|
||||
project_id = json.loads(api_key)["project_id"]
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
|
||||
|
||||
class CloudEmbedding:
|
||||
def __init__(self, api_key: str, provider: str, model: str | None = None):
|
||||
self.api_key = api_key
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
provider: str,
|
||||
# Only for Google, as it is needed during client setup
|
||||
self.model = model
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
self.provider = EmbeddingProvider(provider.lower())
|
||||
except ValueError:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
self.client = self._initialize_client()
|
||||
self.client = _initialize_client(api_key, self.provider, model)
|
||||
|
||||
def _initialize_client(self) -> Any:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return openai.OpenAI(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.COHERE:
|
||||
return CohereClient(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return voyageai.Client(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(self.api_key)
|
||||
)
|
||||
project_id = json.loads(self.api_key)["project_id"]
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
return TextEmbeddingModel.from_pretrained(
|
||||
self.model or DEFAULT_VERTEX_MODEL
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
def encode(
|
||||
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
|
||||
) -> list[list[float]]:
|
||||
return [
|
||||
self.embed(text=text, text_type=text_type, model=model_name)
|
||||
for text in texts
|
||||
]
|
||||
|
||||
def embed(
|
||||
self, *, text: str, text_type: EmbedTextType, model: str | None = None
|
||||
) -> list[float]:
|
||||
logger.debug(f"Embedding text with provider: {self.provider}")
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(text, model)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(text, model, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return self._embed_voyage(text, model, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return self._embed_vertex(text, model, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
def _embed_openai(self, text: str, model: str | None) -> list[float]:
|
||||
def _embed_openai(
|
||||
self, texts: list[str], model: str | None
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_OPENAI_MODEL
|
||||
|
||||
response = self.client.embeddings.create(input=text, model=model)
|
||||
return response.data[0].embedding
|
||||
# OpenAI does not seem to provide a truncation option, however
|
||||
# the context lengths used by Danswer currently are smaller than the max token length
|
||||
# for OpenAI embeddings so it's not a big deal
|
||||
response = self.client.embeddings.create(input=texts, model=model)
|
||||
return [embedding.embedding for embedding in response.data]
|
||||
|
||||
def _embed_cohere(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_COHERE_MODEL
|
||||
|
||||
# Does not use the same tokenizer as the Danswer API server but it's approximately the same
|
||||
# empirically it's only off by a very few tokens so it's not a big deal
|
||||
response = self.client.embed(
|
||||
texts=[text],
|
||||
texts=texts,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
truncate="END",
|
||||
)
|
||||
return response.embeddings[0]
|
||||
return response.embeddings
|
||||
|
||||
def _embed_voyage(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_VOYAGE_MODEL
|
||||
|
||||
response = self.client.embed(text, model=model, input_type=embedding_type)
|
||||
return response.embeddings[0]
|
||||
# Similar to Cohere, the API server will do approximate size chunking
|
||||
# it's acceptable to miss by a few tokens
|
||||
response = self.client.embed(
|
||||
texts,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
truncation=True, # Also this is default
|
||||
)
|
||||
return response.embeddings
|
||||
|
||||
def _embed_vertex(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_VERTEX_MODEL
|
||||
|
||||
embedding = self.client.get_embeddings(
|
||||
embeddings = self.client.get_embeddings(
|
||||
[
|
||||
TextEmbeddingInput(
|
||||
text,
|
||||
embedding_type,
|
||||
)
|
||||
]
|
||||
for text in texts
|
||||
],
|
||||
auto_truncate=True, # Also this is default
|
||||
)
|
||||
return embedding[0].values
|
||||
return [embedding.values for embedding in embeddings]
|
||||
|
||||
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
|
||||
def embed(
|
||||
self,
|
||||
*,
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None = None,
|
||||
) -> list[list[float] | None]:
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error embedding text with {self.provider}: {str(e)}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
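Editor's note: for orientation, here is a minimal usage sketch of the batched `embed` interface introduced above. It assumes `CloudEmbedding`, `EmbeddingProvider`, and `EmbedTextType` are importable as in the module shown in this hunk; the provider string, texts, and the function wrapper are illustrative placeholders, not values taken from this change.

```
# Hedged usage sketch: CloudEmbedding is constructed with a provider string and an
# optional model (None falls back to the provider default), then called with a batch.
from shared_configs.enums import EmbedTextType


def embed_queries_example(api_key: str) -> None:
    client = CloudEmbedding(
        api_key=api_key,
        provider="cohere",  # assumed to resolve to EmbeddingProvider.COHERE
        model=None,  # falls back to e.g. DEFAULT_COHERE_MODEL
    )
    # The new signature is batched: a list of texts in, a list of vectors out.
    vectors = client.embed(
        texts=["what is danswer?", "how do connectors work?"],
        text_type=EmbedTextType.QUERY,
    )
    for vector in vectors:
        print(len(vector) if vector is not None else "empty input")
```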
@@ -166,7 +190,7 @@ def get_embedding_model(

if model_name not in _GLOBAL_MODELS_DICT:
logger.info(f"Loading {model_name}")
model = SentenceTransformer(model_name)
model = SentenceTransformer(model_name, trust_remote_code=True)
model.max_seq_length = max_context_length
_GLOBAL_MODELS_DICT[model_name] = model
elif max_context_length != _GLOBAL_MODELS_DICT[model_name].max_seq_length:
@@ -212,34 +236,83 @@ def embed_text(
normalize_embeddings: bool,
api_key: str | None,
provider_type: str | None,
) -> list[list[float]]:
if provider_type is not None:
prefix: str | None,
) -> list[list[float] | None]:
non_empty_texts = []
empty_indices = []

for idx, text in enumerate(texts):
if text.strip():
non_empty_texts.append(text)
else:
empty_indices.append(idx)

# Third party API based embedding model
if not non_empty_texts:
embeddings = []
elif provider_type is not None:
logger.debug(f"Embedding text with provider: {provider_type}")
if api_key is None:
raise RuntimeError("API key not provided for cloud model")

if prefix:
# This may change in the future if some providers require the user
# to manually append a prefix but this is not the case currently
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
)

cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.encode(texts, model_name, text_type)
embeddings = cloud_model.embed(
texts=non_empty_texts,
model_name=model_name,
text_type=text_type,
)

elif model_name is not None:
hosted_model = get_embedding_model(
prefixed_texts = (
[f"{prefix}{text}" for text in non_empty_texts]
if prefix
else non_empty_texts
)
local_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = hosted_model.encode(
texts, normalize_embeddings=normalize_embeddings
embeddings = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
)

else:
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)

if embeddings is None:
raise RuntimeError("Embeddings were not created")
raise RuntimeError("Failed to create Embeddings")

if not isinstance(embeddings, list):
embeddings = embeddings.tolist()
embeddings_with_nulls: list[list[float] | None] = []
current_embedding_index = 0

for idx in range(len(texts)):
if idx in empty_indices:
embeddings_with_nulls.append(None)
else:
embedding = embeddings[current_embedding_index]
if isinstance(embedding, list) or embedding is None:
embeddings_with_nulls.append(embedding)
else:
embeddings_with_nulls.append(embedding.tolist())
current_embedding_index += 1

embeddings = embeddings_with_nulls
return embeddings

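Editor's note: the change above filters out empty strings before embedding and re-inserts `None` placeholders afterwards so the output stays index-aligned with the input. A small self-contained sketch of just that alignment step (the `embed_fn` callable is a stand-in, not a function from this change):

```
# Minimal sketch of the "skip empty texts, pad with None" step used above.
from typing import Callable


def embed_with_null_padding(
    texts: list[str],
    embed_fn: Callable[[list[str]], list[list[float]]],
) -> list[list[float] | None]:
    non_empty = [t for t in texts if t.strip()]
    vectors = embed_fn(non_empty) if non_empty else []

    padded: list[list[float] | None] = []
    vector_iter = iter(vectors)
    for text in texts:
        # Empty / whitespace-only inputs get a None placeholder so indices line up.
        padded.append(next(vector_iter) if text.strip() else None)
    return padded


# Example with a fake embedder that returns a 2-dim vector per text:
print(embed_with_null_padding(["hello", "", "world"], lambda ts: [[1.0, 2.0] for _ in ts]))
# -> [[1.0, 2.0], None, [1.0, 2.0]]
```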
@simple_log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float] | None]:
cross_encoders = get_local_reranking_model_ensemble()
sim_scores = [
encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
@@ -252,7 +325,17 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
async def process_embed_request(
embed_request: EmbedRequest,
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")

try:
if embed_request.text_type == EmbedTextType.QUERY:
prefix = embed_request.manual_query_prefix
elif embed_request.text_type == EmbedTextType.PASSAGE:
prefix = embed_request.manual_passage_prefix
else:
prefix = None

embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
@@ -261,13 +344,13 @@ async def process_embed_request(
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)
except Exception as e:
logger.exception(f"Error during embedding process:\n{str(e)}")
raise HTTPException(
status_code=500, detail="Failed to run Bi-Encoder embedding"
)
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
raise HTTPException(status_code=500, detail=exception_detail)


@router.post("/cross-encoder-scores")
@@ -276,6 +359,11 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")

if not embed_request.documents or not embed_request.query:
raise HTTPException(
status_code=400, detail="No documents or query to be reranked"
)

try:
sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents

@@ -1,6 +1,8 @@
einops==0.8.0
fastapi==0.109.2
h5py==3.9.0
pydantic==1.10.13
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
tensorflow==2.15.0
@@ -9,5 +11,5 @@ transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
openai==1.14.3
cohere==5.5.8
google-cloud-aiplatform==1.58.0
cohere==5.6.1
google-cloud-aiplatform==1.58.0

@@ -21,18 +21,17 @@ def run_jobs(exclude_indexing: bool) -> None:
cmd_worker = [
"celery",
"-A",
"ee.danswer.background.celery",
"ee.danswer.background.celery.celery_app",
"worker",
"--pool=threads",
"--autoscale=3,10",
"--concurrency=16",
"--loglevel=INFO",
"--concurrency=1",
]

cmd_beat = [
"celery",
"-A",
"ee.danswer.background.celery",
"ee.danswer.background.celery.celery_app",
"beat",
"--loglevel=INFO",
]
@@ -74,7 +73,7 @@ def run_jobs(exclude_indexing: bool) -> None:
try:
update_env = os.environ.copy()
update_env["PYTHONPATH"] = "."
cmd_perm_sync = ["python", "ee.danswer/background/permission_sync.py"]
cmd_perm_sync = ["python", "ee/danswer/background/permission_sync.py"]

indexing_process = subprocess.Popen(
cmd_perm_sync,

@@ -14,7 +14,10 @@ sys.path.append(parent_dir)
# flake8: noqa: E402

# Now import Danswer modules
from danswer.db.models import DocumentSet__ConnectorCredentialPair
from danswer.db.models import (
DocumentSet__ConnectorCredentialPair,
UserGroup__ConnectorCredentialPair,
)
from danswer.db.connector import fetch_connector_by_id
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.index_attempt import (
@@ -44,7 +47,7 @@ logger = setup_logger()
_DELETION_BATCH_SIZE = 1000


def unsafe_deletion(
def _unsafe_deletion(
db_session: Session,
document_index: DocumentIndex,
cc_pair: ConnectorCredentialPair,
@@ -82,11 +85,22 @@ def unsafe_deletion(
credential_id=credential_id,
)

# Delete document sets + connector / credential Pairs
# Delete document sets
stmt = delete(DocumentSet__ConnectorCredentialPair).where(
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id == pair_id
)
db_session.execute(stmt)

# delete user group associations
stmt = delete(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.cc_pair_id == pair_id
)
db_session.execute(stmt)

# need to flush to avoid foreign key violations
db_session.flush()

# delete the actual connector credential pair
stmt = delete(ConnectorCredentialPair).where(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
@@ -168,7 +182,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)

files_deleted_count = unsafe_deletion(
files_deleted_count = _unsafe_deletion(
db_session=db_session,
document_index=document_index,
cc_pair=cc_pair,

@@ -4,9 +4,7 @@ from shared_configs.enums import EmbedTextType


class EmbedRequest(BaseModel):
# This already includes any prefixes, the text is just passed directly to the model
texts: list[str]

# Can be none for cloud embedding model requests, error handling logic exists for other cases
model_name: str | None
max_context_length: int
@@ -14,10 +12,12 @@ class EmbedRequest(BaseModel):
api_key: str | None
provider_type: str | None
text_type: EmbedTextType
manual_query_prefix: str | None
manual_passage_prefix: str | None


class EmbedResponse(BaseModel):
embeddings: list[list[float]]
embeddings: list[list[float] | None]


class RerankRequest(BaseModel):
@@ -26,7 +26,7 @@ class RerankRequest(BaseModel):


class RerankResponse(BaseModel):
scores: list[list[float]]
scores: list[list[float] | None]


class IntentRequest(BaseModel):

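Editor's note: a hedged sketch of the request/response shape implied by the updated models above. Only the field names come from `EmbedRequest`/`EmbedResponse`; every value here, the string form of `text_type`, and the idea of where such a payload would be posted are illustrative assumptions, not taken from this change.

```
# Hedged sketch of the payload shape for the updated EmbedRequest/EmbedResponse.
embed_request = {
    "texts": ["what changed in this release?", ""],
    "model_name": None,  # may be None for cloud providers
    "max_context_length": 512,
    "api_key": "<provider-key>",
    "provider_type": "openai",
    "text_type": "query",
    "manual_query_prefix": None,  # prefixes are rejected for cloud providers
    "manual_passage_prefix": None,
}

# A matching response now allows null entries where the input text was empty.
embed_response = {"embeddings": [[0.1, 0.2, 0.3], None]}
for text, vector in zip(embed_request["texts"], embed_response["embeddings"]):
    print(repr(text), "->", "skipped" if vector is None else f"{len(vector)}-dim vector")
```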
@@ -25,7 +25,7 @@ autorestart=true
# relatively compute-light (e.g. they tend to just make a bunch of requests to
# Vespa / Postgres)
[program:celery_worker]
command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads --autoscale=3,10 --loglevel=INFO --logfile=/var/log/celery_worker.log
command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads --concurrency=16 --loglevel=INFO --logfile=/var/log/celery_worker.log
stdout_logfile=/var/log/celery_worker_supervisor.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true

@@ -9,7 +9,7 @@ This Python script automates the process of running search quality tests for a b
- Manages environment variables
- Switches to specified Git branch
- Uploads test documents
- Runs search quality tests using Relari
- Runs search quality tests
- Cleans up Docker containers (optional)

## Usage
@@ -29,9 +29,17 @@ export PYTHONPATH=$PYTHONPATH:$PWD/backend
```
cd backend/tests/regression/answer_quality
```
7. Run the script:
7. To launch the evaluation environment, run the launch_eval_env.py script (this step can be skipped if you are running the environment outside of Docker; just leave "environment_name" blank):
```
python run_eval_pipeline.py
python launch_eval_env.py
```
8. Run the file_uploader.py script to upload the zip files located at the path "zipped_documents_file"
```
python file_uploader.py
```
9. Run the run_qa.py script to ask questions from the jsonl located at the path "questions_file". This will hit the "query/answer-with-quote" API endpoint.
```
python run_qa.py
```

Note: All data will be saved even after the containers are shut down. There are instructions below for re-launching docker containers using this data.
@@ -61,6 +69,11 @@ Edit `search_test_config.yaml` to set:
- Set this to true to automatically delete all docker containers, networks and volumes after the test
- launch_web_ui
- Set this to true if you want to use the UI during/after the testing process
- only_state
- Whether to only run Vespa and Postgres
- only_retrieve_docs
- Set true to only retrieve documents, not LLM response
- This is to save on API costs
- use_cloud_gpu
- Set to true or false depending on whether you want to use the remote GPU
- Only need to set this if use_cloud_gpu is true
@@ -70,12 +83,10 @@ Edit `search_test_config.yaml` to set:
- model_server_port
- This is the port of the remote model server
- Only need to set this if use_cloud_gpu is true
- existing_test_suffix (THIS IS NOT A SUFFIX ANYMORE, TODO UPDATE THE DOCS HERE)
- environment_name
- Use this if you would like to relaunch a previous test instance
- Input the suffix of the test you'd like to re-launch
- (E.g. to use the data from folder "test-1234-5678" put "-1234-5678")
- No new files will automatically be uploaded
- Leave empty to run a new test
- Input the env_name of the test you'd like to re-launch
- Leave empty to launch referencing local default network locations
- limit
- Max number of questions you'd like to ask against the dataset
- Set to null for no limit
@@ -85,7 +96,7 @@ Edit `search_test_config.yaml` to set:

## Relaunching From Existing Data

To launch an existing set of containers that has already completed indexing, set the existing_test_suffix variable. This will launch the docker containers mounted on the volumes of the indicated suffix and will not automatically index any documents or run any QA.
To launch an existing set of containers that has already completed indexing, set the environment_name variable. This will launch the docker containers mounted on the volumes of the indicated env_name and will not automatically index any documents or run any QA.

Once these containers are launched you can run file_uploader.py or run_qa.py (assuming you have run the steps in the Usage section above).
- file_uploader.py will upload and index additional zipped files located at the zipped_documents_file path.

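Editor's note: taken together, steps 7 through 9 can also be driven from a short Python script. This is only an illustrative sketch built from the helpers referenced elsewhere in this change; the config path and the presence of `manage_file_upload` in file_uploader.py are assumptions based on the diff, and all values come from your own search_test_config.yaml.

```
# Illustrative sketch only: chains the README steps using helpers from this change.
from types import SimpleNamespace

import yaml

from tests.regression.answer_quality.file_uploader import manage_file_upload
from tests.regression.answer_quality.run_qa import run_qa_test_and_save_results

with open("search_test_config.yaml", "r") as f:
    config = SimpleNamespace(**yaml.safe_load(f))

# Step 8: upload and index the zipped documents.
manage_file_upload(config.zipped_documents_file, config.environment_name)

# Step 9: ask the questions from questions_file and save the results.
run_qa_test_and_save_results(config.environment_name)
```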
@@ -2,45 +2,31 @@ import requests
from retry import retry

from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.connectors.models import InputType
from danswer.db.enums import IndexingStatus
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.models import IndexFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.server.documents.models import ConnectorBase
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
from tests.regression.answer_quality.cli_utils import get_api_server_host_port

GENERAL_HEADERS = {"Content-Type": "application/json"}


def _api_url_builder(run_suffix: str, api_path: str) -> str:
return f"http://localhost:{get_api_server_host_port(run_suffix)}" + api_path


def _create_new_chat_session(run_suffix: str) -> int:
create_chat_request = ChatSessionCreationRequest(
persona_id=0,
description=None,
)
body = create_chat_request.dict()

create_chat_url = _api_url_builder(run_suffix, "/chat/create-chat-session/")

response_json = requests.post(
create_chat_url, headers=GENERAL_HEADERS, json=body
).json()
chat_session_id = response_json.get("chat_session_id")

if isinstance(chat_session_id, int):
return chat_session_id
def _api_url_builder(env_name: str, api_path: str) -> str:
if env_name:
return f"http://localhost:{get_api_server_host_port(env_name)}" + api_path
else:
raise RuntimeError(response_json)
return "http://localhost:8080" + api_path


@retry(tries=10, delay=10)
def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
@retry(tries=5, delay=5)
def get_answer_from_query(
query: str, only_retrieve_docs: bool, env_name: str
) -> tuple[list[str], str]:
filters = IndexFilters(
source_type=None,
document_set=None,
@@ -48,42 +34,47 @@ def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
tags=None,
access_control_list=None,
)
retrieval_options = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
filters=filters,
enable_auto_detect_filters=False,

messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]

new_message_request = DirectQARequest(
messages=messages,
prompt_id=0,
persona_id=0,
retrieval_options=RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
filters=filters,
enable_auto_detect_filters=False,
),
chain_of_thought=False,
return_contexts=True,
skip_gen_ai_answer_generation=only_retrieve_docs,
)

chat_session_id = _create_new_chat_session(run_suffix)

url = _api_url_builder(run_suffix, "/chat/send-message-simple-api/")

new_message_request = BasicCreateChatMessageRequest(
chat_session_id=chat_session_id,
message=query,
retrieval_options=retrieval_options,
query_override=query,
)
url = _api_url_builder(env_name, "/query/answer-with-quote/")
headers = {
"Content-Type": "application/json",
}

body = new_message_request.dict()
body["user"] = None
try:
response_json = requests.post(url, headers=GENERAL_HEADERS, json=body).json()
simple_search_docs = response_json.get("simple_search_docs", [])
answer = response_json.get("answer", "")
response_json = requests.post(url, headers=headers, json=body).json()
context_data_list = response_json.get("contexts", {}).get("contexts", [])
answer = response_json.get("answer", "") or ""
except Exception as e:
print("Failed to answer the questions:")
print(f"\t {str(e)}")
print("trying again")
print("Try restarting vespa container and trying again")
raise e

return simple_search_docs, answer
return context_data_list, answer


@retry(tries=10, delay=10)
def check_if_query_ready(run_suffix: str) -> bool:
url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/")
def check_indexing_status(env_name: str) -> tuple[int, bool]:
url = _api_url_builder(env_name, "/manage/admin/connector/indexing-status/")
try:
indexing_status_dict = requests.get(url, headers=GENERAL_HEADERS).json()
except Exception as e:
@@ -98,20 +89,21 @@ def check_if_query_ready(run_suffix: str) -> bool:
status = index_attempt["last_status"]
if status == IndexingStatus.IN_PROGRESS or status == IndexingStatus.NOT_STARTED:
ongoing_index_attempts = True
elif status == IndexingStatus.SUCCESS:
doc_count += 16
doc_count += index_attempt["docs_indexed"]
doc_count -= 16

if not doc_count:
print("No docs indexed, waiting for indexing to start")
elif ongoing_index_attempts:
print(
f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
)

return doc_count > 0 and not ongoing_index_attempts
# all the +16 and -16 are to account for the fact that the indexing status
# is only updated every 16 documents and will tell us how many are
# chunked, not indexed. probably need to fix this. in the future!
if doc_count:
doc_count += 16
return doc_count, ongoing_index_attempts


def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/")
def run_cc_once(env_name: str, connector_id: int, credential_id: int) -> None:
url = _api_url_builder(env_name, "/manage/admin/connector/run-once/")
body = {
"connector_id": connector_id,
"credential_ids": [credential_id],
@@ -126,9 +118,9 @@ def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
print("Failed text:", response.text)


def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> None:
def create_cc_pair(env_name: str, connector_id: int, credential_id: int) -> None:
url = _api_url_builder(
run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}"
env_name, f"/manage/connector/{connector_id}/credential/{credential_id}"
)

body = {"name": "zip_folder_contents", "is_public": True}
@@ -141,8 +133,8 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
print("Failed text:", response.text)


def _get_existing_connector_names(run_suffix: str) -> list[str]:
url = _api_url_builder(run_suffix, "/manage/connector")
def _get_existing_connector_names(env_name: str) -> list[str]:
url = _api_url_builder(env_name, "/manage/connector")

body = {
"credential_json": {},
@@ -156,10 +148,10 @@ def _get_existing_connector_names(run_suffix: str) -> list[str]:
raise RuntimeError(response.__dict__)


def create_connector(run_suffix: str, file_paths: list[str]) -> int:
url = _api_url_builder(run_suffix, "/manage/admin/connector")
def create_connector(env_name: str, file_paths: list[str]) -> int:
url = _api_url_builder(env_name, "/manage/admin/connector")
connector_name = base_connector_name = "search_eval_connector"
existing_connector_names = _get_existing_connector_names(run_suffix)
existing_connector_names = _get_existing_connector_names(env_name)

count = 1
while connector_name in existing_connector_names:
@@ -186,8 +178,8 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
raise RuntimeError(response.__dict__)


def create_credential(run_suffix: str) -> int:
url = _api_url_builder(run_suffix, "/manage/credential")
def create_credential(env_name: str) -> int:
url = _api_url_builder(env_name, "/manage/credential")
body = {
"credential_json": {},
"admin_public": True,
@@ -201,12 +193,12 @@ def create_credential(run_suffix: str) -> int:


@retry(tries=10, delay=2, backoff=2)
def upload_file(run_suffix: str, zip_file_path: str) -> list[str]:
def upload_file(env_name: str, zip_file_path: str) -> list[str]:
files = [
("files", open(zip_file_path, "rb")),
]

api_path = _api_url_builder(run_suffix, "/manage/admin/connector/file/upload")
api_path = _api_url_builder(env_name, "/manage/admin/connector/file/upload")
try:
response = requests.post(api_path, files=files)
response.raise_for_status()  # Raises an HTTPError for bad responses

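Editor's note: several helpers above lean on the `retry` package (e.g. `@retry(tries=5, delay=5)` and `@retry(tries=10, delay=2, backoff=2)`) to ride out a model server or Vespa that is still warming up. A small sketch of the same pattern; the function, URL, and timeout are placeholders for illustration.

```
# Sketch of the retry pattern used by the API helpers above (same `retry` package).
import requests
from retry import retry


@retry(tries=5, delay=5)
def flaky_request(url: str) -> dict:
    # Any exception (connection error or non-2xx via raise_for_status) triggers a retry.
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    return response.json()
```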
@@ -67,20 +67,20 @@ def switch_to_commit(commit_sha: str) -> None:
print("Repository updated successfully.")


def get_docker_container_env_vars(suffix: str) -> dict:
def get_docker_container_env_vars(env_name: str) -> dict:
"""
Retrieves environment variables from "background" and "api_server" Docker containers.
"""
print(f"Getting environment variables for containers with suffix: {suffix}")
print(f"Getting environment variables for containers with env_name: {env_name}")

combined_env_vars = {}
for container_type in ["background", "api_server"]:
container_name = _run_command(
f"docker ps -a --format '{{{{.Names}}}}' | awk '/{container_type}/ && /{suffix}/'"
f"docker ps -a --format '{{{{.Names}}}}' | awk '/{container_type}/ && /{env_name}/'"
)[0].strip()
if not container_name:
raise RuntimeError(
f"No {container_type} container found with suffix: {suffix}"
f"No {container_type} container found with env_name: {env_name}"
)

env_vars_json = _run_command(
@@ -95,9 +95,9 @@ def get_docker_container_env_vars(suffix: str) -> dict:
return combined_env_vars


def manage_data_directories(suffix: str, base_path: str, use_cloud_gpu: bool) -> None:
def manage_data_directories(env_name: str, base_path: str, use_cloud_gpu: bool) -> None:
# Use the user's home directory as the base path
target_path = os.path.join(os.path.expanduser(base_path), suffix)
target_path = os.path.join(os.path.expanduser(base_path), env_name)
directories = {
"DANSWER_POSTGRES_DATA_DIR": os.path.join(target_path, "postgres/"),
"DANSWER_VESPA_DATA_DIR": os.path.join(target_path, "vespa/"),
@@ -144,12 +144,12 @@ def _is_port_in_use(port: int) -> bool:


def start_docker_compose(
run_suffix: str, launch_web_ui: bool, use_cloud_gpu: bool, only_state: bool = False
env_name: str, launch_web_ui: bool, use_cloud_gpu: bool, only_state: bool = False
) -> None:
print("Starting Docker Compose...")
os.chdir(os.path.dirname(__file__))
os.chdir("../../../../deployment/docker_compose/")
command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack-{run_suffix} up -d"
command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack-{env_name} up -d"
command += " --build"
command += " --force-recreate"

@@ -175,17 +175,17 @@ def start_docker_compose(
print("Containers have been launched")


def cleanup_docker(run_suffix: str) -> None:
def cleanup_docker(env_name: str) -> None:
print(
f"Deleting Docker containers, volumes, and networks for project suffix: {run_suffix}"
f"Deleting Docker containers, volumes, and networks for project env_name: {env_name}"
)

stdout, _ = _run_command("docker ps -a --format '{{json .}}'")

containers = [json.loads(line) for line in stdout.splitlines()]
if not run_suffix:
run_suffix = datetime.now().strftime("-%Y")
project_name = f"danswer-stack{run_suffix}"
if not env_name:
env_name = datetime.now().strftime("-%Y")
project_name = f"danswer-stack{env_name}"
containers_to_delete = [
c for c in containers if c["Names"].startswith(project_name)
]
@@ -221,23 +221,23 @@ def cleanup_docker(run_suffix: str) -> None:

networks = stdout.splitlines()

networks_to_delete = [n for n in networks if run_suffix in n]
networks_to_delete = [n for n in networks if env_name in n]

if not networks_to_delete:
print(f"No networks found containing suffix: {run_suffix}")
print(f"No networks found containing env_name: {env_name}")
else:
network_names = " ".join(networks_to_delete)
_run_command(f"docker network rm {network_names}")

print(
f"Successfully deleted {len(networks_to_delete)} networks containing suffix: {run_suffix}"
f"Successfully deleted {len(networks_to_delete)} networks containing env_name: {env_name}"
)


@retry(tries=5, delay=5, backoff=2)
def get_api_server_host_port(suffix: str) -> str:
def get_api_server_host_port(env_name: str) -> str:
"""
This pulls all containers with the provided suffix
This pulls all containers with the provided env_name
It then grabs the JSON specific container with a name containing "api_server"
It then grabs the port info from the JSON and strips out the relevant data
"""
@@ -248,16 +248,16 @@ def get_api_server_host_port(suffix: str) -> str:
server_jsons = []

for container in containers:
if container_name in container["Names"] and suffix in container["Names"]:
if container_name in container["Names"] and env_name in container["Names"]:
server_jsons.append(container)

if not server_jsons:
raise RuntimeError(
f"No container found containing: {container_name} and {suffix}"
f"No container found containing: {container_name} and {env_name}"
)
elif len(server_jsons) > 1:
raise RuntimeError(
f"Too many containers matching {container_name} found, please indicate a suffix"
f"Too many containers matching {container_name} found, please indicate an env_name"
)
server_json = server_jsons[0]

@@ -278,67 +278,37 @@ def get_api_server_host_port(suffix: str) -> str:
raise RuntimeError(f"Too many ports matching {client_port} found")
if not matching_ports:
raise RuntimeError(
f"No port found containing: {client_port} for container: {container_name} and suffix: {suffix}"
f"No port found containing: {client_port} for container: {container_name} and env_name: {env_name}"
)
return matching_ports[0]


# Added function to check Vespa container health status
def is_vespa_container_healthy(suffix: str) -> bool:
print(f"Checking health status of Vespa container for suffix: {suffix}")

# Find the Vespa container
stdout, _ = _run_command(
f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
)
container_name = stdout.strip()

if not container_name:
print(f"No Vespa container found with suffix: {suffix}")
return False

# Get the health status
stdout, _ = _run_command(
f"docker inspect --format='{{{{.State.Health.Status}}}}' {container_name}"
)
health_status = stdout.strip()

is_healthy = health_status.lower() == "healthy"
print(f"Vespa container '{container_name}' health status: {health_status}")

return is_healthy


# Added function to restart Vespa container
def restart_vespa_container(suffix: str) -> None:
print(f"Restarting Vespa container for suffix: {suffix}")
def restart_vespa_container(env_name: str) -> None:
print(f"Restarting Vespa container for env_name: {env_name}")

# Find the Vespa container
stdout, _ = _run_command(
f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
f"docker ps -a --format '{{{{.Names}}}}' | awk '/index-1/ && /{env_name}/'"
)
container_name = stdout.strip()

if not container_name:
raise RuntimeError(f"No Vespa container found with suffix: {suffix}")
raise RuntimeError(f"No Vespa container found with env_name: {env_name}")

# Restart the container
_run_command(f"docker restart {container_name}")

print(f"Vespa container '{container_name}' has begun restarting")

time_to_wait = 5
while not is_vespa_container_healthy(suffix):
print(f"Waiting {time_to_wait} seconds for vespa container to restart")
time.sleep(5)

time.sleep(30)
print(f"Vespa container '{container_name}' has been restarted")


if __name__ == "__main__":
"""
Running this just cleans up the docker environment for the container indicated by existing_test_suffix
If no existing_test_suffix is indicated, will just clean up all danswer docker containers/volumes/networks
Running this just cleans up the docker environment for the container indicated by environment_name
If no environment_name is indicated, will just clean up all danswer docker containers/volumes/networks
Note: vespa/postgres mounts are not deleted
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -348,4 +318,4 @@ if __name__ == "__main__":

if not isinstance(config, dict):
raise TypeError("config must be a dictionary")
cleanup_docker(config["existing_test_suffix"])
cleanup_docker(config["environment_name"])

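Editor's note: the CLI helpers above locate project containers by piping `docker ps` through awk and matching on the env_name fragment. A hedged sketch of the same idea done with subprocess and plain Python filtering; the function name and example arguments are placeholders, not part of this change.

```
# Hedged sketch: find a container whose name contains both a role fragment and the
# env_name, mirroring the awk-based lookups in the helpers above.
import subprocess


def find_container(name_fragment: str, env_name: str) -> str | None:
    out = subprocess.run(
        ["docker", "ps", "-a", "--format", "{{.Names}}"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    matches = [n for n in out.splitlines() if name_fragment in n and env_name in n]
    return matches[0] if matches else None


# Example (names are placeholders): find_container("api_server", "my-eval-env")
```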
@@ -1,8 +1,13 @@
import os
import tempfile
import time
import zipfile
from pathlib import Path
from types import SimpleNamespace

import yaml

from tests.regression.answer_quality.api_utils import check_indexing_status
from tests.regression.answer_quality.api_utils import create_cc_pair
from tests.regression.answer_quality.api_utils import create_connector
from tests.regression.answer_quality.api_utils import create_credential
@@ -10,15 +15,65 @@ from tests.regression.answer_quality.api_utils import run_cc_once
from tests.regression.answer_quality.api_utils import upload_file


def upload_test_files(zip_file_path: str, run_suffix: str) -> None:
def unzip_and_get_file_paths(zip_file_path: str) -> list[str]:
persistent_dir = tempfile.mkdtemp()
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
zip_ref.extractall(persistent_dir)

return [str(path) for path in Path(persistent_dir).rglob("*") if path.is_file()]


def create_temp_zip_from_files(file_paths: list[str]) -> str:
persistent_dir = tempfile.mkdtemp()
zip_file_path = os.path.join(persistent_dir, "temp.zip")

with zipfile.ZipFile(zip_file_path, "w") as zip_file:
for file_path in file_paths:
zip_file.write(file_path, Path(file_path).name)

return zip_file_path


def upload_test_files(zip_file_path: str, env_name: str) -> None:
print("zip:", zip_file_path)
file_paths = upload_file(run_suffix, zip_file_path)
file_paths = upload_file(env_name, zip_file_path)

conn_id = create_connector(run_suffix, file_paths)
cred_id = create_credential(run_suffix)
conn_id = create_connector(env_name, file_paths)
cred_id = create_credential(env_name)

create_cc_pair(run_suffix, conn_id, cred_id)
run_cc_once(run_suffix, conn_id, cred_id)
create_cc_pair(env_name, conn_id, cred_id)
run_cc_once(env_name, conn_id, cred_id)


def manage_file_upload(zip_file_path: str, env_name: str) -> None:
unzipped_file_paths = unzip_and_get_file_paths(zip_file_path)
total_file_count = len(unzipped_file_paths)

while True:
doc_count, ongoing_index_attempts = check_indexing_status(env_name)

if ongoing_index_attempts:
print(
f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
)
elif not doc_count:
print("No docs indexed, waiting for indexing to start")
upload_test_files(zip_file_path, env_name)
elif doc_count < total_file_count:
print(f"No ongoing indexing attempts but only {doc_count} docs indexed")
remaining_files = unzipped_file_paths[doc_count:]
print(f"Grabbed last {len(remaining_files)} docs to try again")
temp_zip_file_path = create_temp_zip_from_files(remaining_files)
upload_test_files(temp_zip_file_path, env_name)
os.unlink(temp_zip_file_path)
else:
print(f"Successfully uploaded {doc_count} docs!")
break

time.sleep(10)

for file in unzipped_file_paths:
os.unlink(file)


if __name__ == "__main__":
@@ -27,5 +82,5 @@ if __name__ == "__main__":
with open(config_path, "r") as file:
config = SimpleNamespace(**yaml.safe_load(file))
file_location = config.zipped_documents_file
run_suffix = config.existing_test_suffix
upload_test_files(file_location, run_suffix)
env_name = config.environment_name
manage_file_upload(file_location, env_name)

@@ -1,16 +1,12 @@
import os
from datetime import datetime
from types import SimpleNamespace

import yaml

from tests.regression.answer_quality.cli_utils import cleanup_docker
from tests.regression.answer_quality.cli_utils import manage_data_directories
from tests.regression.answer_quality.cli_utils import set_env_variables
from tests.regression.answer_quality.cli_utils import start_docker_compose
from tests.regression.answer_quality.cli_utils import switch_to_commit
from tests.regression.answer_quality.file_uploader import upload_test_files
from tests.regression.answer_quality.run_qa import run_qa_test_and_save_results


def load_config(config_filename: str) -> SimpleNamespace:
@@ -22,12 +18,16 @@ def load_config(config_filename: str) -> SimpleNamespace:

def main() -> None:
config = load_config("search_test_config.yaml")
if config.existing_test_suffix:
run_suffix = config.existing_test_suffix
print("launching danswer with existing data suffix:", run_suffix)
if config.environment_name:
env_name = config.environment_name
print("launching danswer with environment name:", env_name)
else:
run_suffix = datetime.now().strftime("-%Y%m%d-%H%M%S")
print("run_suffix:", run_suffix)
print("No env name defined. Not launching docker.")
print(
"Please define a name in the config yaml to start a new env "
"or use an existing env"
)
return

set_env_variables(
config.model_server_ip,
@@ -35,22 +35,14 @@ def main() -> None:
config.use_cloud_gpu,
config.llm,
)
manage_data_directories(run_suffix, config.output_folder, config.use_cloud_gpu)
manage_data_directories(env_name, config.output_folder, config.use_cloud_gpu)
if config.commit_sha:
switch_to_commit(config.commit_sha)

start_docker_compose(
run_suffix, config.launch_web_ui, config.use_cloud_gpu, config.only_state
env_name, config.launch_web_ui, config.use_cloud_gpu, config.only_state
)

if not config.existing_test_suffix and not config.only_state:
upload_test_files(config.zipped_documents_file, run_suffix)

run_qa_test_and_save_results(run_suffix)

if config.clean_up_docker_containers:
cleanup_docker(run_suffix)


if __name__ == "__main__":
main()
@@ -6,7 +6,6 @@ import time

import yaml

from tests.regression.answer_quality.api_utils import check_if_query_ready
from tests.regression.answer_quality.api_utils import get_answer_from_query
from tests.regression.answer_quality.cli_utils import get_current_commit_sha
from tests.regression.answer_quality.cli_utils import get_docker_container_env_vars
@@ -44,12 +43,12 @@ def _read_questions_jsonl(questions_file_path: str) -> list[dict]:

def _get_test_output_folder(config: dict) -> str:
base_output_folder = os.path.expanduser(config["output_folder"])
if config["run_suffix"]:
if config["env_name"]:
base_output_folder = os.path.join(
base_output_folder, config["run_suffix"], "evaluations_output"
base_output_folder, config["env_name"], "evaluations_output"
)
else:
base_output_folder = os.path.join(base_output_folder, "no_defined_suffix")
base_output_folder = os.path.join(base_output_folder, "no_defined_env_name")

counter = 1
output_folder_path = os.path.join(base_output_folder, "run_1")
@@ -73,12 +72,12 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:

metadata = {
"commit_sha": get_current_commit_sha(),
"run_suffix": config["run_suffix"],
"env_name": config["env_name"],
"test_config": config,
"number_of_questions_in_dataset": len(questions),
}

env_vars = get_docker_container_env_vars(config["run_suffix"])
env_vars = get_docker_container_env_vars(config["env_name"])
if env_vars["ENV_SEED_CONFIGURATION"]:
del env_vars["ENV_SEED_CONFIGURATION"]
if env_vars["GPG_KEY"]:
@@ -118,7 +117,8 @@ def _process_question(question_data: dict, config: dict, question_number: int) -
print(f"query: {query}")
context_data_list, answer = get_answer_from_query(
query=query,
run_suffix=config["run_suffix"],
only_retrieve_docs=config["only_retrieve_docs"],
env_name=config["env_name"],
)

if not context_data_list:
@@ -142,27 +142,23 @@ def _process_and_write_query_results(config: dict) -> None:
test_output_folder, questions = _initialize_files(config)
print("saving test results to folder:", test_output_folder)

while not check_if_query_ready(config["run_suffix"]):
time.sleep(5)

if config["limit"] is not None:
questions = questions[: config["limit"]]

with multiprocessing.Pool(processes=multiprocessing.cpu_count() * 2) as pool:
# Use multiprocessing to process questions
with multiprocessing.Pool() as pool:
results = pool.starmap(
_process_question, [(q, config, i + 1) for i, q in enumerate(questions)]
_process_question,
[(question, config, i + 1) for i, question in enumerate(questions)],
)

_populate_results_file(test_output_folder, results)

invalid_answer_count = 0
for result in results:
if not result.get("answer"):
if len(result["context_data_list"]) == 0:
invalid_answer_count += 1

if not result.get("context_data_list"):
raise RuntimeError("Search failed, this is a critical failure!")

_update_metadata_file(test_output_folder, invalid_answer_count)

if invalid_answer_count:
@@ -177,7 +173,7 @@ def _process_and_write_query_results(config: dict) -> None:
print("saved test results to folder:", test_output_folder)


def run_qa_test_and_save_results(run_suffix: str = "") -> None:
def run_qa_test_and_save_results(env_name: str = "") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, "search_test_config.yaml")
with open(config_path, "r") as file:
@@ -186,16 +182,16 @@ def run_qa_test_and_save_results(run_suffix: str = "") -> None:
if not isinstance(config, dict):
raise TypeError("config must be a dictionary")

if not run_suffix:
run_suffix = config["existing_test_suffix"]
if not env_name:
env_name = config["environment_name"]

config["run_suffix"] = run_suffix
config["env_name"] = env_name
_process_and_write_query_results(config)


if __name__ == "__main__":
"""
To run a different set of questions, update the questions_file in search_test_config.yaml
If there is more than one instance of Danswer running, specify the suffix in search_test_config.yaml
If there is more than one instance of Danswer running, specify the env_name in search_test_config.yaml
"""
run_qa_test_and_save_results()

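Editor's note: the rewritten loop above fans questions out with `multiprocessing.Pool.starmap`. A tiny self-contained sketch of that call shape; `process_question` here is a stand-in for `_process_question(question, config, question_number)`, not the real worker.

```
# Self-contained sketch of the Pool.starmap call shape used above.
import multiprocessing


def process_question(question: dict, config: dict, question_number: int) -> dict:
    return {"number": question_number, "query": question["question"]}


if __name__ == "__main__":
    questions = [{"question": "q1"}, {"question": "q2"}]
    config: dict = {}
    with multiprocessing.Pool() as pool:
        results = pool.starmap(
            process_question,
            [(question, config, i + 1) for i, question in enumerate(questions)],
        )
    print(results)
```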
@@ -13,14 +13,11 @@ questions_file: "~/sample_questions.yaml"
# Git commit SHA to use (null means use current code as is)
commit_sha: null

# Whether to remove Docker containers after the test
clean_up_docker_containers: true

# Whether to launch a web UI for the test
launch_web_ui: false

# Whether to only run Vespa and Postgres
only_state: false
# Only retrieve documents, not LLM response
only_retrieve_docs: false

# Whether to use a cloud GPU for processing
use_cloud_gpu: false
@@ -31,9 +28,8 @@ model_server_ip: "PUT_PUBLIC_CLOUD_IP_HERE"
# Port of the model server (placeholder)
model_server_port: "PUT_PUBLIC_CLOUD_PORT_HERE"

# Suffix for existing test results (E.g. -1234-5678)
# empty string means no suffix
existing_test_suffix: ""
# Name for existing testing env (empty string uses default ports)
environment_name: ""

# Limit on number of tests to run (null means no limit)
limit: null

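Editor's note: a short sketch of how two of the options above flow into the QA run, mirroring the checks in run_qa.py; the concrete values are placeholders.

```
# Sketch: how `limit` and `only_retrieve_docs` from the config are consumed.
config = {"limit": 25, "only_retrieve_docs": True, "environment_name": ""}

questions = [{"question": f"q{i}"} for i in range(100)]
if config["limit"] is not None:
    questions = questions[: config["limit"]]  # cap the number of questions asked

# only_retrieve_docs is forwarded as skip_gen_ai_answer_generation, so only search
# results come back and no LLM answer is generated.
mode = "retrieval only" if config["only_retrieve_docs"] else "full QA"
print(len(questions), "questions;", mode)
```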
@@ -31,8 +31,8 @@ def test_chunk_document() -> None:

chunks = chunk_document(document)
assert len(chunks) == 5
assert all(semantic_identifier in chunk.content for chunk in chunks)
assert short_section_1 in chunks[0].content
assert short_section_3 in chunks[-1].content
assert short_section_4 in chunks[-1].content
assert "tag1" in chunks[0].content
assert "tag1" in chunks[0].metadata_suffix_keyword
assert "tag2" in chunks[0].metadata_suffix_semantic

@@ -25,8 +25,8 @@ services:
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
- REQUIRE_EMAIL_VERIFICATION=${REQUIRE_EMAIL_VERIFICATION:-}
- SMTP_SERVER=${SMTP_SERVER:-} # For sending verification emails, if unspecified then defaults to 'smtp.gmail.com'
- SMTP_PORT=${SMTP_PORT:-587} # For sending verification emails, if unspecified then defaults to '587'
- SMTP_SERVER=${SMTP_SERVER:-} # For sending verification emails, if unspecified then defaults to 'smtp.gmail.com'
- SMTP_PORT=${SMTP_PORT:-587} # For sending verification emails, if unspecified then defaults to '587'
- SMTP_USER=${SMTP_USER:-}
- SMTP_PASS=${SMTP_PASS:-}
- EMAIL_FROM=${EMAIL_FROM:-}
@@ -59,8 +59,8 @@ services:
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
- AWS_REGION_NAME=${AWS_REGION_NAME:-}
# Query Options
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
- LANGUAGE_HINT=${LANGUAGE_HINT:-}
@@ -69,7 +69,7 @@ services:
# Other services
- POSTGRES_HOST=relational_db
- VESPA_HOST=index
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose
# Don't change the NLP model configs unless you know what you're doing
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
@@ -104,7 +104,6 @@ services:
max-size: "50m"
max-file: "6"


background:
image: danswer/danswer-backend:latest
build:
@@ -139,8 +138,8 @@ services:
- LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
- BING_API_KEY=${BING_API_KEY:-}
# Query Options
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
- LANGUAGE_HINT=${LANGUAGE_HINT:-}
@@ -152,12 +151,12 @@ services:
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
- POSTGRES_DB=${POSTGRES_DB:-}
- VESPA_HOST=index
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors
# Don't change the NLP model configs unless you know what you're doing
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} # Needed by DanswerBot
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} # Needed by DanswerBot
- ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
@@ -183,7 +182,7 @@ services:
- DANSWER_BOT_FEEDBACK_VISIBILITY=${DANSWER_BOT_FEEDBACK_VISIBILITY:-}
- DANSWER_BOT_DISPLAY_ERROR_MSGS=${DANSWER_BOT_DISPLAY_ERROR_MSGS:-}
- DANSWER_BOT_RESPOND_EVERY_CHANNEL=${DANSWER_BOT_RESPOND_EVERY_CHANNEL:-}
- DANSWER_BOT_DISABLE_COT=${DANSWER_BOT_DISABLE_COT:-} # Currently unused
- DANSWER_BOT_DISABLE_COT=${DANSWER_BOT_DISABLE_COT:-} # Currently unused
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
- DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-}
- DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-}
@@ -207,7 +206,6 @@ services:
max-size: "50m"
max-file: "6"


web_server:
image: danswer/danswer-web-server:latest
build:
@@ -219,6 +217,7 @@ services:
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}

# Enterprise Edition only
- NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
@@ -236,7 +235,6 @@ services:
# Enterprise Edition only
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}


inference_model_server:
image: danswer/danswer-model-server:latest
build:
@@ -263,7 +261,6 @@ services:
max-size: "50m"
max-file: "6"


indexing_model_server:
image: danswer/danswer-model-server:latest
build:
@@ -291,7 +288,6 @@ services:
max-size: "50m"
max-file: "6"


relational_db:
image: postgres:15.2-alpine
restart: always
@@ -303,7 +299,6 @@ services:
volumes:
- db_volume:/var/lib/postgresql/data


# This container name cannot have an underscore in it due to Vespa expectations of the URL
index:
image: vespaengine/vespa:8.277.17
@@ -319,7 +314,6 @@ services:
max-size: "50m"
max-file: "6"


nginx:
image: nginx:1.23.4-alpine
restart: always
@@ -332,7 +326,7 @@ services:
- DOMAIN=localhost
ports:
- "80:80"
- "3000:80" # allow for localhost:3000 usage, since that is the norm
- "3000:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
logging:
@@ -349,10 +343,9 @@ services:
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev"


volumes:
db_volume:
vespa_volume:
# Created by the container itself
vespa_volume: # Created by the container itself

model_cache_huggingface:
indexing_huggingface_model_cache:

@@ -211,6 +211,7 @@ services:
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
- NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
depends_on:
- api_server

@@ -70,6 +70,7 @@ services:
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
- NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
depends_on:
- api_server

@@ -75,6 +75,7 @@ services:
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}

# Enterprise Edition only
- NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}

@@ -46,6 +46,9 @@ ENV NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PRED
ARG NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS
ENV NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS}

ARG NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN
ENV NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN}

ARG NEXT_PUBLIC_THEME
ENV NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME}

@@ -106,6 +109,9 @@ ENV NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME}
ARG NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED
ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED}

ARG NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN
ENV NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN}

ARG NEXT_PUBLIC_DISABLE_LOGOUT
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}

@@ -436,7 +436,7 @@ export function AssistantEditor({
              <div className="mb-6">
                <div className="flex gap-x-2 items-center">
                  <div className="block font-medium text-base">
                    LLM Provider{" "}
                    LLM Override{" "}
                  </div>
                  <TooltipProvider delayDuration={50}>
                    <Tooltip>
@@ -452,6 +452,10 @@ export function AssistantEditor({
                    </Tooltip>
                  </TooltipProvider>
                </div>
                <p className="my-1 text-text-600">
                  You assistant will use your system default (currently{" "}
                  {defaultModelName}) unless otherwise specified below.
                </p>
                <div className="mb-2 flex items-starts">
                  <div className="w-96">
                    <SelectorFormField

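This hunk renames the field label from "LLM Provider" to "LLM Override" and adds copy explaining that the assistant falls back to the system default model when no override is set. A rough sketch of that fallback logic (the interface and field names below are assumptions, not taken from the repository):

// Hypothetical shape for the sketch; the real persona model may differ.
interface AssistantLLMOverride {
  llmModelVersionOverride?: string | null;
}

// Resolve the model an assistant should use: the per-assistant override when
// set, otherwise the system default mentioned in the editor copy above.
export function resolveAssistantModel(
  assistant: AssistantLLMOverride,
  defaultModelName: string
): string {
  return assistant.llmModelVersionOverride ?? defaultModelName;
}
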
@@ -136,10 +136,6 @@ export default function Page({ params }: { params: { ccPairId: string } }) {

  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <Main ccPairId={ccPairId} />
    </div>
  );

@@ -248,10 +248,6 @@ const MainSection = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<AxeroIcon size={32} />} title="Axero" />

      <MainSection />

@@ -249,10 +249,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<BookstackIcon size={32} />} title="Bookstack" />

      <Main />

@@ -331,10 +331,6 @@ const MainSection = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<ClickupIcon size={32} />} title="Clickup" />

      <MainSection />

@@ -330,10 +330,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<ConfluenceIcon size={32} />} title="Confluence" />

      <Main />

@@ -273,10 +273,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<DiscourseIcon size={32} />} title="Discourse" />

      <Main />

@@ -262,10 +262,6 @@ const MainSection = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle
        icon={<Document360Icon size={32} />}
        title="Document360"

@@ -212,9 +212,7 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>
      {" "}
      <AdminPageTitle icon={<DropboxIcon size={32} />} title="Dropbox" />
      <Main />
    </div>

@@ -287,10 +287,6 @@ const Main = () => {
export default function File() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<FileIcon size={32} />} title="File" />

      <Main />

@@ -265,10 +265,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="container mx-auto">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle
        icon={<GithubIcon size={32} />}
        title="Github PRs + Issues"

@@ -250,10 +250,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="container mx-auto">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle
        icon={<GitlabIcon size={32} />}
        title="Gitlab MRs + Issues"

@@ -264,10 +264,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<GmailIcon size={32} />} title="Gmail" />

      <Main />

@@ -257,10 +257,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<GongIcon size={32} />} title="Gong" />

      <Main />

@@ -412,10 +412,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle
        icon={<GoogleDriveIcon size={32} />}
        title="Google Drive"

@@ -50,10 +50,6 @@ export default function GoogleSites() {
      {popup}
      {filesAreUploading && <Spinner />}
      <div className="mx-auto container">
        <div className="mb-4">
          <HealthCheckBanner />
        </div>

        <AdminPageTitle
          icon={<GoogleSitesIcon size={32} />}
          title="Google Sites"

@@ -244,9 +244,7 @@ const GCSMain = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>
      {" "}
      <AdminPageTitle
        icon={<GoogleStorageIcon size={32} />}
        title="Google Cloud Storage"

@@ -232,10 +232,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<GuruIcon size={32} />} title="Guru" />

      <Main />

@@ -220,10 +220,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<HubSpotIcon size={32} />} title="HubSpot" />

      <Main />

@@ -362,10 +362,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<JiraIcon size={32} />} title="Jira" />

      <Main />

@@ -224,10 +224,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<LinearIcon size={32} />} title="Linear" />

      <Main />

@@ -250,9 +250,7 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>
      {" "}
      <div className="border-solid border-gray-600 border-b mb-4 pb-2 flex">
        <LoopioIcon size={32} />
        <h1 className="text-3xl font-bold pl-2">Loopio</h1>

@@ -207,10 +207,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<MediaWikiIcon size={32} />} title="MediaWiki" />

      <Main />

@@ -260,10 +260,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle icon={<NotionIcon size={32} />} title="Notion" />

      <Main />

@@ -259,9 +259,7 @@ const OCIMain = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>
      {" "}
      <AdminPageTitle
        icon={<OCIStorageIcon size={32} />}
        title="Oracle Cloud Infrastructure"

@@ -239,10 +239,6 @@ const Main = () => {
export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle
        icon={<ProductboardIcon size={32} />}
        title="Productboard"

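Each of the connector admin pages touched above shares the same shell: a container div, the HealthCheckBanner, an AdminPageTitle with the connector's icon, and the connector-specific main section. A condensed sketch of that shared layout (the component names appear in the hunks; the import paths and props here are assumptions for illustration):

// Sketch of the page shell the connector admin pages above have in common.
// Import paths are assumed for illustration only.
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { AdminPageTitle } from "@/components/admin/Title";
import { NotionIcon } from "@/components/icons/icons";

export default function Page() {
  return (
    <div className="mx-auto container">
      <div className="mb-4">
        <HealthCheckBanner />
      </div>
      <AdminPageTitle icon={<NotionIcon size={32} />} title="Notion" />
      {/* connector-specific form, credential handling, and status table */}
    </div>
  );
}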