Compare commits

..

44 Commits

Author SHA1 Message Date
Yuhong Sun
9d48c79de9 k 2024-07-24 20:55:57 -07:00
hagen-danswer
a4d71e08aa Added check for unknown tool names (#1924)
* answer.py

* Let it continue if broken
2024-07-25 00:19:08 +00:00
rkuo-danswer
546bfbd24b autoscale with pool=thread crashes celery. remove and use concurrency… (#1929)
* autoscale with pool=thread crashes celery. remove and use concurrency instead (to be improved later)

* update dev background script as well
2024-07-25 00:15:27 +00:00
hagen-danswer
27824d6cc6 Fixed login issue (#1920)
* included check for existing emails

* cleaned up logic
2024-07-25 00:03:29 +00:00
Weves
9d5c4ad634 Small fix for non tool calling LLMs 2024-07-24 15:41:43 -07:00
Shukant Pal
9b32003816 Handle SSL error tracebacks in site indexing connector (#1911)
My website (https://shukantpal.com) uses Let's Encrypt certificates, which aren't accepted by the Python urllib certificate verifier for some reason. My website is set up correctly otherwise (https://www.sslshopper.com/ssl-checker.html#hostname=www.shukantpal.com)

This change adds a fix so the correct traceback is shown in Danswer, instead of a generic "unable to connect, check your Internet connection".
2024-07-24 22:36:29 +00:00
pablodanswer
8bc4123ed7 add modern health check banner + expiration tracking (#1730)
---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2024-07-24 15:34:22 -07:00
pablodanswer
d58aaf7a59 add href 2024-07-24 14:33:56 -07:00
pablodanswer
a0056a1b3c add files (images) (#1926) 2024-07-24 21:26:01 +00:00
pablodanswer
d2584c773a slightly clearer description of model settings in assistants creation tab (#1925) 2024-07-24 21:25:30 +00:00
pablodanswer
807bef8ada Add environment variable for defaulted sidebar toggling (#1923)
* add env variable for defaulted sidebar toggling

* formatting

* update naming
2024-07-24 21:23:37 +00:00
rkuo-danswer
5afddacbb2 order list of new attempts from oldest to newest to prevent connector starvation (#1918) 2024-07-24 21:02:20 +00:00
hagen-danswer
4fb6a88f1e Quick fix (#1919) 2024-07-24 11:56:14 -07:00
rkuo-danswer
7057be6a88 Bugfix/indexing progress (#1916)
* mark in progress should always be committed

* no_commit version of mark_attempt is not needed
2024-07-24 11:39:44 -07:00
Yuhong Sun
91be8e7bfb Skip Null Docs (#1917) 2024-07-24 11:31:33 -07:00
Yuhong Sun
9651ea828b Handling Metadata by Vector and Keyword (#1909) 2024-07-24 11:05:56 -07:00
rkuo-danswer
6ee74bd0d1 fix pointers to various background tasks and scripts (#1914) 2024-07-24 10:12:51 -07:00
pablodanswer
48a0d29a5c Fix empty / reverted embeddings (#1910) 2024-07-23 22:41:31 -07:00
hagen-danswer
6ff8e6c0ea Improve eval pipeline qol (#1908) 2024-07-23 17:16:34 -07:00
Yuhong Sun
2470c68506 Don't rephrase first chat query (#1907) 2024-07-23 16:20:11 -07:00
hagen-danswer
866bc803b1 Implemented LLM disabling for api call (#1905) 2024-07-23 16:12:51 -07:00
pablodanswer
9c6084bd0d Embeddings- Clean up modal + "Important" call out (#1903) 2024-07-22 21:29:22 -07:00
hagen-danswer
a0b46c60c6 Switched eval api target back to oneshotqa (#1902) 2024-07-22 20:55:18 -07:00
pablodanswer
4029233df0 hide incomplete sources for non-admins (#1901) 2024-07-22 13:40:11 -07:00
hagen-danswer
6c88c0156c Added file upload retry logic (#1889) 2024-07-22 13:13:22 -07:00
pablodanswer
33332d08f2 fix citation title (#1900)
* fix citation title

* remove title function
2024-07-22 17:37:04 +00:00
hagen-danswer
17005fb705 switched default pruning behavior and removed some logging (#1898) 2024-07-22 17:36:26 +00:00
hagen-danswer
48a7fe80b1 Committed LLM updates to db (#1899) 2024-07-22 10:30:24 -07:00
pablodanswer
1276732409 Misc bug fixes (#1895) 2024-07-22 10:22:43 -07:00
Weves
f91b92a898 Make is_public default true for LLMProvider 2024-07-21 22:22:37 -07:00
Weves
6222f533be Update force delete script to handle user groups 2024-07-21 22:22:37 -07:00
hagen-danswer
1b49d17239 Added ability to control LLM access based on group (#1870)
* Added ability to control LLM access based on group

* completed relationship deletion

* cleaned up function

* added comments

* fixed frontend strings

* mypy fixes

* added case handling for deletion of user groups

* hidden advanced options now

* removed unnecessary code
2024-07-22 04:31:44 +00:00
Yuhong Sun
2f5f19642e Double Check Max Tokens for Indexing (#1893) 2024-07-21 21:12:39 -07:00
Yuhong Sun
6db4634871 Token Truncation (#1892) 2024-07-21 16:26:32 -07:00
Yuhong Sun
5cfed45cef Handle Empty Titles (#1891) 2024-07-21 14:59:23 -07:00
Weves
581ffde35a Fix jira connector failures for server deployments 2024-07-21 14:44:25 -07:00
pablodanswer
6313e6d91d Remove visit api when unneeded (#1885)
* quick fix to test on ec2

* quick cleanup

* modify a name

* address full doc as well

* additional timing info + handling

* clean up

* squash

* Print only
2024-07-21 20:57:24 +00:00
Weves
c09c94bf32 Fix assistant swap 2024-07-21 13:57:36 -07:00
Yuhong Sun
0e8ba111c8 Model Touchups (#1887) 2024-07-21 12:31:00 -07:00
Yuhong Sun
2ba24b1734 Reenable Search Pipeline (#1886) 2024-07-21 10:33:29 -07:00
Yuhong Sun
44820b4909 k 2024-07-21 10:27:57 -07:00
hagen-danswer
eb3e7610fc Added retries and multithreading for cloud embedding (#1879)
* added retries and multithreading for cloud embedding

* refactored a bit

* cleaned up code

* got the errors to bubble up to the ui correctly

* added exception printing

* added requirements

* touchups

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-07-20 22:10:18 -07:00
pablodanswer
7fbbb174bb minor fixes (#1882)
- Assistants tab size
- Fixed logo -> absolute
2024-07-20 21:02:57 -07:00
pablodanswer
3854ca11af add newlines for message content 2024-07-20 18:57:29 -07:00
150 changed files with 2050 additions and 1620 deletions

View File

@@ -0,0 +1,41 @@
"""add_llm_group_permissions_control
Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558
"""
from alembic import op
import sqlalchemy as sa
revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"llm_provider__user_group",
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["llm_provider_id"],
["llm_provider.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
)
op.add_column(
"llm_provider",
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_table("llm_provider__user_group")
op.drop_column("llm_provider", "is_public")

View File

@@ -0,0 +1,27 @@
"""add expiry time
Revision ID: 91ffac7e65b3
Revises: bc9771dccadf
Create Date: 2024-06-24 09:39:56.462242
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "91ffac7e65b3"
down_revision = "795b20b85b4b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
)
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("user", "oidc_expiry")
# ### end Alembic commands ###

View File

@@ -1,6 +1,8 @@
import smtplib
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Optional
@@ -52,6 +54,7 @@ from danswer.db.auth import get_user_db
from danswer.db.engine import get_session
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
@@ -92,12 +95,20 @@ def user_needs_to_be_verified() -> bool:
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
def verify_email_in_whitelist(email: str) -> None:
def verify_email_is_invited(email: str) -> None:
whitelist = get_invited_users()
if (whitelist and email not in whitelist) or not email:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(
email: str,
db_session: Session = Depends(get_session),
) -> None:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
def verify_email_domain(email: str) -> None:
if VALID_EMAIL_DOMAINS:
if email.count("@") != 1:
@@ -147,7 +158,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> models.UP:
verify_email_in_whitelist(user_create.email)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
@@ -173,7 +184,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_in_whitelist(account_email)
verify_email_domain(account_email)
return await super().oauth_callback( # type: ignore
user = await super().oauth_callback( # type: ignore
oauth_name=oauth_name,
access_token=access_token,
account_id=account_id,
@@ -185,6 +196,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
is_verified_by_default=is_verified_by_default,
)
# NOTE: google oauth expires after 1hr. We don't want to force the user to
# re-authenticate that frequently, so for now we'll just ignore this for
# google oauth users
if expires_at and AUTH_TYPE != AuthType.GOOGLE_OAUTH:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
return user
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
@@ -227,10 +246,12 @@ cookie_transport = CookieTransport(
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
strategy = DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy
auth_backend = AuthenticationBackend(
name="database",

View File

@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -80,7 +80,7 @@ def should_prune_cc_pair(
return True
return False
if PREVENT_SIMULTANEOUS_PRUNING:
if not ALLOW_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
@@ -89,11 +89,9 @@ def should_prune_cc_pair(
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
logger.info("Another Connector is already pruning. Skipping.")
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
return False
if not last_pruning_task.start_time:

View File

@@ -20,7 +20,7 @@ from danswer.db.connector_credential_pair import update_connector_credential_pai
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
@@ -299,9 +299,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
)
# only commit once, to make sure this all happens in a single transaction
mark_attempt_in_progress__no_commit(attempt)
if attempt.embedding_model.status != IndexModelStatus.PRESENT:
db_session.commit()
mark_attempt_in_progress(attempt, db_session)
return attempt

View File

@@ -33,7 +33,7 @@ from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -271,6 +271,8 @@ def kickoff_indexing_jobs(
# Don't include jobs waiting in the Dask queue that just haven't started running
# Also (rarely) don't include jobs that started but haven't updated the indexing tables yet
with Session(engine) as db_session:
# get_not_started_index_attempts orders its returned results from oldest to newest
# we must process attempts in a FIFO manner to prevent connector starvation
new_indexing_attempts = [
(attempt, attempt.embedding_model)
for attempt in get_not_started_index_attempts(db_session)

View File

@@ -51,7 +51,7 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
@@ -187,37 +187,46 @@ def _handle_internet_search_tool_response_summary(
)
def _check_should_force_search(
new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
# If files are already provided, don't run the search tool
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool:
internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
)
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Currently, the internet search tool does not support query override
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override and tool_name == SearchTool._NAME
else None
)
if new_msg_req.file_descriptors:
return None
# If user has uploaded files they're using, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
if (
new_msg_req.query_override
or (
should_force_search = any(
[
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
)
or new_msg_req.search_doc_ids
or DISABLE_LLM_CHOOSE_SEARCH
):
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override
else None
)
# if we are using selected docs, just put something here so the Tool doesn't need
# to build its own args via an LLM call
if new_msg_req.search_doc_ids:
args = {"query": new_msg_req.message}
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
DISABLE_LLM_CHOOSE_SEARCH,
]
)
return ForceUseTool(
tool_name=SearchTool._NAME,
args=args,
)
return None
if should_force_search:
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
ChatPacket = (
@@ -253,7 +262,6 @@ def stream_chat_message_objects(
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
try:
user_id = user.id if user is not None else None
@@ -361,6 +369,14 @@ def stream_chat_message_objects(
"when the last message is not a user message."
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worse search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
@@ -576,11 +592,7 @@ def stream_chat_message_objects(
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
tools=tools,
force_use_tool=(
_check_should_force_search(new_msg_req)
if search_tool and len(tools) == 1
else None
),
force_use_tool=_get_force_search_settings(new_msg_req, tools),
)
reference_db_search_docs = None
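
In short, _get_force_search_settings now always returns a ForceUseTool: force_use is only set to True when no user files are attached and one of the forcing conditions holds, and the query override is only passed through for the search tool. A self-contained sketch of just the decision part follows; the boolean parameters stand in for the request fields and this is not the actual function:

def should_force_search(
    has_file_descriptors: bool,
    run_search_always: bool,
    has_search_doc_ids: bool,
    llm_choose_search_disabled: bool,
) -> bool:
    # User-uploaded files take priority over any search tool.
    if has_file_descriptors:
        return False
    return any([run_search_always, has_search_doc_ids, llm_choose_search_disabled])

print(should_force_search(False, True, False, False))   # True
print(should_force_search(True, True, False, False))    # False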

View File

@@ -214,8 +214,8 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
PREVENT_SIMULTANEOUS_PRUNING = (
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.

View File

@@ -44,7 +44,6 @@ QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"
# For chunking/processing chunks
MAX_CHUNK_TITLE_LEN = 1000
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
# For combining attributes, doesn't have to be unique/perfect to work

View File

@@ -12,17 +12,13 @@ import os
# The usable models configured as below must be SentenceTransformer compatible
# NOTE: DO NOT CHANGE THESE UNLESS YOU KNOW WHAT YOU ARE DOING
# IDEALLY, YOU SHOULD CHANGE EMBEDDING MODELS VIA THE UI
DEFAULT_DOCUMENT_ENCODER_MODEL = "intfloat/e5-base-v2"
DOCUMENT_ENCODER_MODEL = (
os.environ.get("DOCUMENT_ENCODER_MODEL") or DEFAULT_DOCUMENT_ENCODER_MODEL
)
DEFAULT_DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1"
DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1"
# If the below is changed, Vespa deployment must also be changed
DOC_EMBEDDING_DIM = int(os.environ.get("DOC_EMBEDDING_DIM") or 768)
# Model should be chosen with 512 context size, ideally don't change this
DOC_EMBEDDING_CONTEXT_SIZE = 512
NORMALIZE_EMBEDDINGS = (
os.environ.get("NORMALIZE_EMBEDDINGS") or "true"
).lower() == "true"
NORMALIZE_EMBEDDINGS = False
# Old default model settings, which are needed for an automatic easy upgrade
OLD_DEFAULT_DOCUMENT_ENCODER_MODEL = "thenlper/gte-small"
@@ -34,8 +30,8 @@ OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS = False
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
ASYM_QUERY_PREFIX = "search_query: "
ASYM_PASSAGE_PREFIX = "search_document: "
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# For score display purposes, the only way is to know the expected ranges
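
With the encoder pinned to nomic-ai/nomic-embed-text-v1, queries and passages now receive the model's asymmetric prefixes unconditionally. A minimal sketch of how such prefixes are typically applied before encoding; prefix_texts is an illustrative helper, not a Danswer function:

ASYM_QUERY_PREFIX = "search_query: "
ASYM_PASSAGE_PREFIX = "search_document: "

def prefix_texts(texts: list[str], is_query: bool) -> list[str]:
    # Asymmetric models want to know whether they are embedding a short query
    # or a longer document passage, hence the different prefixes.
    prefix = ASYM_QUERY_PREFIX if is_query else ASYM_PASSAGE_PREFIX
    return [prefix + text for text in texts]

print(prefix_texts(["how do I reset my password"], is_query=True))
# ['search_query: how do I reset my password']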

View File

@@ -56,6 +56,16 @@ def extract_text_from_content(content: dict) -> str:
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
@@ -117,8 +127,10 @@ def fetch_jira_issues_batch(
continue
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments]
semantic_rep = (
f"{jira.fields.description}\n"
if jira.fields.description
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
)
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -147,14 +159,18 @@ def fetch_jira_issues_batch(
pass
metadata_dict = {}
if jira.fields.priority:
metadata_dict["priority"] = jira.fields.priority.name
if jira.fields.status:
metadata_dict["status"] = jira.fields.status.name
if jira.fields.resolution:
metadata_dict["resolution"] = jira.fields.resolution.name
if jira.fields.labels:
metadata_dict["label"] = jira.fields.labels
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
metadata_dict["priority"] = priority.name
status = best_effort_get_field_from_issue(jira, "status")
if status:
metadata_dict["status"] = status.name
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
metadata_dict["resolution"] = resolution.name
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
metadata_dict["label"] = labels
doc_batch.append(
Document(

View File

@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
self.permissions: DiscoursePerms | None = None
self.active_categories: set | None = None
@rate_limit_builder(max_calls=100, period=60)
@rate_limit_builder(max_calls=50, period=60)
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
if not self.permissions:
raise ConnectorMissingCredentialError("Discourse")
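
For context on the halved limit, a decorator like rate_limit_builder(max_calls=50, period=60) caps how often _make_request can fire within a sliding window. Danswer's real implementation lives in rate_limit_wrapper; the sketch below is a generic stand-in under that assumption, not the actual code:

import time
from collections import deque
from functools import wraps

def rate_limit_builder(max_calls: int, period: float):
    """Allow at most `max_calls` invocations per `period` seconds (generic sketch)."""
    def decorator(func):
        call_times: deque = deque()

        @wraps(func)
        def wrapper(*args, **kwargs):
            now = time.monotonic()
            # Discard timestamps that have fallen out of the window.
            while call_times and now - call_times[0] > period:
                call_times.popleft()
            if len(call_times) >= max_calls:
                # Sleep until the oldest call leaves the window.
                time.sleep(period - (now - call_times[0]))
                call_times.popleft()
            call_times.append(time.monotonic())
            return func(*args, **kwargs)

        return wrapper
    return decorator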

View File

@@ -114,7 +114,9 @@ class DocumentBase(BaseModel):
title: str | None = None
from_ingestion_api: bool = False
def get_title_for_document_index(self) -> str | None:
def get_title_for_document_index(
self,
) -> str | None:
# If title is explicitly empty, return a None here for embedding purposes
if self.title == "":
return None

View File

@@ -15,6 +15,7 @@ from playwright.sync_api import BrowserContext
from playwright.sync_api import Playwright
from playwright.sync_api import sync_playwright
from requests_oauthlib import OAuth2Session # type:ignore
from urllib3.exceptions import MaxRetryError
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_ID
@@ -83,6 +84,13 @@ def check_internet_connection(url: str) -> None:
try:
response = requests.get(url, timeout=3)
response.raise_for_status()
except requests.exceptions.SSLError as e:
cause = (
e.args[0].reason
if isinstance(e.args, tuple) and isinstance(e.args[0], MaxRetryError)
else e.args
)
raise Exception(f"SSL error {str(cause)}")
except (requests.RequestException, ValueError):
raise Exception(f"Unable to reach {url} - check your internet connection")

View File

@@ -50,9 +50,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST

View File

@@ -15,7 +15,7 @@ from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)

View File

@@ -65,9 +65,12 @@ def get_inprogress_index_attempts(
def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
"""This eagerly loads the connector and credential so that the db_session can be expired
before running long-living indexing jobs, which causes increasing memory usage"""
before running long-living indexing jobs, which causes increasing memory usage.
Results are ordered by time_created (oldest to newest)."""
stmt = select(IndexAttempt)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
stmt = stmt.order_by(IndexAttempt.time_created)
stmt = stmt.options(
joinedload(IndexAttempt.connector), joinedload(IndexAttempt.credential)
)
@@ -75,11 +78,13 @@ def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
return list(new_attempts.all())
def mark_attempt_in_progress__no_commit(
def mark_attempt_in_progress(
index_attempt: IndexAttempt,
db_session: Session,
) -> None:
index_attempt.status = IndexingStatus.IN_PROGRESS
index_attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
def mark_attempt_succeeded(

View File

@@ -1,15 +1,41 @@
from sqlalchemy import delete
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def update_group_llm_provider_relationships__no_commit(
llm_provider_id: int,
group_ids: list[int] | None,
db_session: Session,
) -> None:
# Delete existing relationships
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
).delete(synchronize_session="fetch")
# Add new relationships from given group_ids
if group_ids:
new_relationships = [
LLMProvider__UserGroup(
llm_provider_id=llm_provider_id,
user_group_id=group_id,
)
for group_id in group_ids
]
db_session.add_all(new_relationships)
def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
@@ -36,36 +62,35 @@ def upsert_llm_provider(
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider:
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = (
llm_provider.fast_default_model_name
)
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
provider=llm_provider.provider,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
default_model_name=llm_provider.default_model_name,
fast_default_model_name=llm_provider.fast_default_model_name,
model_names=llm_provider.model_names,
is_default_provider=None,
if not existing_llm_provider:
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
existing_llm_provider.model_names = llm_provider.model_names
existing_llm_provider.is_public = llm_provider.is_public
if not existing_llm_provider.id:
# If it's not already in the db, we need to generate an ID by flushing
db_session.flush()
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
group_ids=llm_provider.groups,
db_session=db_session,
)
db_session.add(llm_provider_model)
db_session.commit()
return FullLLMProvider.from_model(llm_provider_model)
return FullLLMProvider.from_model(existing_llm_provider)
def fetch_existing_embedding_providers(
@@ -74,8 +99,29 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,
) -> list[LLMProviderModel]:
if not user:
return list(db_session.scalars(select(LLMProviderModel)).all())
stmt = select(LLMProviderModel).distinct()
user_groups_subquery = (
select(User__UserGroup.user_group_id)
.where(User__UserGroup.user_id == user.id)
.subquery()
)
access_conditions = or_(
LLMProviderModel.is_public,
LLMProviderModel.id.in_( # User is part of a group that has access
select(LLMProvider__UserGroup.llm_provider_id).where(
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
)
),
)
stmt = stmt.where(access_conditions)
return list(db_session.scalars(stmt).all())
def fetch_embedding_provider(
@@ -119,6 +165,13 @@ def remove_embedding_provider(
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
# Remove LLMProvider's dependent relationships
db_session.execute(
delete(LLMProvider__UserGroup).where(
LLMProvider__UserGroup.llm_provider_id == provider_id
)
)
# Remove LLMProvider
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)
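
The new fetch_existing_llm_providers query boils down to a simple visibility rule: a provider is usable by a user if it is public or if it shares at least one user group with them. A self-contained restatement over plain data; the names are illustrative, not the ORM query above:

def user_can_see_provider(
    provider_is_public: bool,
    provider_group_ids: set[int],
    user_group_ids: set[int],
) -> bool:
    # Mirrors the or_() condition in the SQLAlchemy statement above.
    return provider_is_public or bool(provider_group_ids & user_group_ids)

print(user_can_see_provider(False, {1, 2}, {2, 7}))  # True: group 2 is shared
print(user_can_see_provider(False, {1}, {3}))        # False: private and no shared group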

View File

@@ -11,6 +11,7 @@ from uuid import UUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
from fastapi_users_db_sqlalchemy.generics import TIMESTAMPAware
from sqlalchemy import Boolean
from sqlalchemy import DateTime
from sqlalchemy import Enum
@@ -120,6 +121,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
postgresql.ARRAY(Integer), nullable=True
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
)
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user", lazy="joined"
@@ -932,6 +937,13 @@ class LLMProvider(Base):
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="llm_provider__user_group",
viewonly=True,
)
class CloudEmbeddingProvider(Base):
@@ -1109,7 +1121,6 @@ class Persona(Base):
# where lower value IDs (e.g. created earlier) are displayed first
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
# These are only defaults, users can select from all if desired
prompts: Mapped[list[Prompt]] = relationship(
@@ -1137,6 +1148,7 @@ class Persona(Base):
viewonly=True,
)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="persona__user_group",
@@ -1360,6 +1372,17 @@ class Persona__UserGroup(Base):
)
class LLMProvider__UserGroup(Base):
__tablename__ = "llm_provider__user_group"
llm_provider_id: Mapped[int] = mapped_column(
ForeignKey("llm_provider.id"), primary_key=True
)
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
class DocumentSet__UserGroup(Base):
__tablename__ = "document_set__user_group"

View File

@@ -153,41 +153,43 @@ schema DANSWER_CHUNK_NAME {
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
}
function title_vector_score() {
# This must be a separate function for normalize_linear to work
function vector_score() {
expression {
# If no title, the full vector score comes from the content embedding
#query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
(query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))) +
((1 - query(title_content_ratio)) * closeness(field, embeddings))
}
}
# This must be a separate function for normalize_linear to work
function keyword_score() {
expression {
(query(title_content_ratio) * bm25(title)) +
((1 - query(title_content_ratio)) * bm25(content))
}
}
first-phase {
expression: closeness(field, embeddings)
expression: vector_score
}
# Weighted average between Vector Search and BM-25
# Each is a weighted average between the Title and Content fields
# Finally each doc is boosted by its user-feedback-based boost and recency
# If any embedding or index field is missing, it just receives a score of 0
# Assumptions:
# - For a given query + corpus, the BM-25 scores will be relatively similar in distribution
# therefore not normalizing before combining.
# - Documents without a title get a title score of 0, which is ok since documents
# without any title match should be penalized.
global-phase {
expression {
(
# Weighted Vector Similarity Score
(
query(alpha) * (
(query(title_content_ratio) * normalize_linear(title_vector_score))
+
((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings)))
)
)
+
(query(alpha) * normalize_linear(vector_score)) +
# Weighted Keyword Similarity Score
(
(1 - query(alpha)) * (
(query(title_content_ratio) * normalize_linear(bm25(title)))
+
((1 - query(title_content_ratio)) * normalize_linear(bm25(content)))
)
)
((1 - query(alpha)) * normalize_linear(keyword_score))
)
# Boost based on user feedback
* document_boost
@@ -202,6 +204,8 @@ schema DANSWER_CHUNK_NAME {
bm25(content)
closeness(field, title_embedding)
closeness(field, embeddings)
keyword_score
vector_score
document_boost
recency_bias
closest(embeddings)
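
The reworked global-phase expression is a weighted average of a hybrid vector score and a hybrid keyword score, each itself a title/content blend, multiplied by the document boost and recency signals. A small numeric sketch of the same formula in Python, ignoring the skip_title fallback and the normalize_linear rescaling (all values illustrative):

def hybrid_score(
    alpha: float,                # query(alpha): vector vs. keyword weight
    title_content_ratio: float,  # query(title_content_ratio)
    title_sim: float,
    content_sim: float,
    bm25_title: float,
    bm25_content: float,
    document_boost: float,
    recency_bias: float,
) -> float:
    vector_score = (
        title_content_ratio * title_sim + (1 - title_content_ratio) * content_sim
    )
    keyword_score = (
        title_content_ratio * bm25_title + (1 - title_content_ratio) * bm25_content
    )
    return (
        alpha * vector_score + (1 - alpha) * keyword_score
    ) * document_boost * recency_bias

print(hybrid_score(0.7, 0.3, 0.9, 0.8, 12.0, 9.0, 1.0, 1.0))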

View File

@@ -331,12 +331,18 @@ def _index_vespa_chunk(
document = chunk.source_document
# No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
vespa_chunk_id = str(get_uuid_from_chunk(chunk))
embeddings = chunk.embeddings
if chunk.embeddings.full_embedding is None:
embeddings.full_embedding = chunk.title_embedding
embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
if embeddings.mini_chunk_embeddings:
for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
if m_c_embed is None:
embeddings_name_vector_map[f"mini_chunk_{ind}"] = chunk.title_embedding
else:
embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
title = document.get_title_for_document_index()
@@ -346,11 +352,15 @@ def _index_vespa_chunk(
BLURB: remove_invalid_unicode_chars(chunk.blurb),
TITLE: remove_invalid_unicode_chars(title) if title else None,
SKIP_TITLE_EMBEDDING: not title,
CONTENT: remove_invalid_unicode_chars(chunk.content),
# For the BM25 index, the keyword suffix is used; the vector is already generated with the more
# natural-language representation of the metadata section
CONTENT: remove_invalid_unicode_chars(
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
),
# This duplication of `content` is needed for keyword highlighting
# Note that it's not exactly the same as the actual content
# which contains the title prefix and metadata suffix
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content_summary),
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content),
SOURCE_TYPE: str(document.source.value),
SOURCE_LINKS: json.dumps(chunk.source_links),
SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
@@ -358,7 +368,7 @@ def _index_vespa_chunk(
METADATA: json.dumps(document.metadata),
# Save as a list for efficient extraction as an Attribute
METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
METADATA_SUFFIX: chunk.metadata_suffix,
METADATA_SUFFIX: chunk.metadata_suffix_keyword,
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
BOOST: chunk.boost,

View File

@@ -103,6 +103,7 @@ def port_api_key_to_postgres() -> None:
default_model_name=default_model_name,
fast_default_model_name=default_fast_model_name,
model_names=None,
is_public=True,
)
llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
update_default_provider(db_session, llm_provider.id)

View File

@@ -6,7 +6,6 @@ from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.app_configs import SKIP_METADATA_IN_CHUNK
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.constants import SECTION_SEPARATOR
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
@@ -15,12 +14,12 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
)
from danswer.connectors.models import Document
from danswer.indexing.models import DocAwareChunk
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import shared_precompare_cleanup
if TYPE_CHECKING:
from transformers import AutoTokenizer # type:ignore
from llama_index.text_splitter import SentenceSplitter # type:ignore
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
@@ -28,6 +27,8 @@ if TYPE_CHECKING:
CHUNK_OVERLAP = 0
# Fairly arbitrary numbers, but the general concept is that we don't want the title/metadata to
# overwhelm the actual contents of the chunk.
# For example, in a rare case the metadata could be 128 tokens of a 512-token chunk and the title prefix
# could be another 128 tokens, leaving 256 for the actual contents.
MAX_METADATA_PERCENTAGE = 0.25
CHUNK_MIN_CONTENT = 256
@@ -36,14 +37,7 @@ logger = setup_logger()
ChunkFunc = Callable[[Document], list[DocAwareChunk]]
def extract_blurb(text: str, blurb_size: int) -> str:
from llama_index.text_splitter import SentenceSplitter
token_count_func = get_default_tokenizer().tokenize
blurb_splitter = SentenceSplitter(
tokenizer=token_count_func, chunk_size=blurb_size, chunk_overlap=0
)
def extract_blurb(text: str, blurb_splitter: "SentenceSplitter") -> str:
return blurb_splitter.split_text(text)[0]
@@ -52,33 +46,25 @@ def chunk_large_section(
section_link_text: str,
document: Document,
start_chunk_id: int,
tokenizer: "AutoTokenizer",
chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
blurb_size: int = BLURB_SIZE,
blurb: str,
chunk_splitter: "SentenceSplitter",
title_prefix: str = "",
metadata_suffix: str = "",
metadata_suffix_semantic: str = "",
metadata_suffix_keyword: str = "",
) -> list[DocAwareChunk]:
from llama_index.text_splitter import SentenceSplitter
blurb = extract_blurb(section_text, blurb_size)
sentence_aware_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
split_texts = sentence_aware_splitter.split_text(section_text)
split_texts = chunk_splitter.split_text(section_text)
chunks = [
DocAwareChunk(
source_document=document,
chunk_id=start_chunk_id + chunk_ind,
blurb=blurb,
content=f"{title_prefix}{chunk_str}{metadata_suffix}",
content_summary=chunk_str,
content=chunk_str,
source_links={0: section_link_text},
section_continuation=(chunk_ind != 0),
metadata_suffix=metadata_suffix,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
for chunk_ind, chunk_str in enumerate(split_texts)
]
@@ -86,42 +72,87 @@ def chunk_large_section(
def _get_metadata_suffix_for_document_index(
metadata: dict[str, str | list[str]]
) -> str:
metadata: dict[str, str | list[str]], include_separator: bool = False
) -> tuple[str, str]:
"""
Returns the metadata as a natural language string representation with all of the keys and values for the vector embedding
and a string of all of the values for the keyword search
For example, if we have the following metadata:
{
"author": "John Doe",
"space": "Engineering"
}
The vector embedding string should include the relation between the key and value, whereas for keyword search we only want John Doe
and Engineering. The keys are repetitive and much more noisy.
"""
if not metadata:
return ""
return "", ""
metadata_str = "Metadata:\n"
metadata_values = []
for key, value in metadata.items():
if key in get_metadata_keys_to_ignore():
continue
value_str = ", ".join(value) if isinstance(value, list) else value
if isinstance(value, list):
metadata_values.extend(value)
else:
metadata_values.append(value)
metadata_str += f"\t{key} - {value_str}\n"
return metadata_str.strip()
metadata_semantic = metadata_str.strip()
metadata_keyword = " ".join(metadata_values)
if include_separator:
return RETURN_SEPARATOR + metadata_semantic, RETURN_SEPARATOR + metadata_keyword
return metadata_semantic, metadata_keyword
def chunk_document(
document: Document,
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
subsection_overlap: int = CHUNK_OVERLAP,
blurb_size: int = BLURB_SIZE,
blurb_size: int = BLURB_SIZE, # Used for both title and content
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
) -> list[DocAwareChunk]:
from llama_index.text_splitter import SentenceSplitter
tokenizer = get_default_tokenizer()
title = document.get_title_for_document_index()
title_prefix = f"{title[:MAX_CHUNK_TITLE_LEN]}{RETURN_SEPARATOR}" if title else ""
blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize, chunk_size=blurb_size, chunk_overlap=0
)
chunk_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
chunk_size=chunk_tok_size,
chunk_overlap=subsection_overlap,
)
title = extract_blurb(document.get_title_for_document_index() or "", blurb_splitter)
title_prefix = title + RETURN_SEPARATOR if title else ""
title_tokens = len(tokenizer.tokenize(title_prefix))
metadata_suffix = ""
metadata_suffix_semantic = ""
metadata_suffix_keyword = ""
metadata_tokens = 0
if include_metadata:
metadata = _get_metadata_suffix_for_document_index(document.metadata)
metadata_suffix = RETURN_SEPARATOR + metadata if metadata else ""
metadata_tokens = len(tokenizer.tokenize(metadata_suffix))
(
metadata_suffix_semantic,
metadata_suffix_keyword,
) = _get_metadata_suffix_for_document_index(
document.metadata, include_separator=True
)
metadata_tokens = len(tokenizer.tokenize(metadata_suffix_semantic))
if metadata_tokens >= chunk_tok_size * MAX_METADATA_PERCENTAGE:
metadata_suffix = ""
# Note: we can keep the keyword suffix even if the semantic suffix is too long to fit in the model
# context; there is no limit for the keyword component
metadata_suffix_semantic = ""
metadata_tokens = 0
content_token_limit = chunk_tok_size - title_tokens - metadata_tokens
@@ -130,7 +161,7 @@ def chunk_document(
if content_token_limit <= CHUNK_MIN_CONTENT:
content_token_limit = chunk_tok_size
title_prefix = ""
metadata_suffix = ""
metadata_suffix_semantic = ""
chunks: list[DocAwareChunk] = []
link_offsets: dict[int, str] = {}
@@ -151,12 +182,13 @@ def chunk_document(
DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
blurb=extract_blurb(chunk_text, blurb_splitter),
content=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
)
link_offsets = {}
@@ -167,12 +199,11 @@ def chunk_document(
section_link_text=section_link_text,
document=document,
start_chunk_id=len(chunks),
tokenizer=tokenizer,
chunk_size=content_token_limit,
chunk_overlap=subsection_overlap,
blurb_size=blurb_size,
chunk_splitter=chunk_splitter,
blurb=extract_blurb(section_text, blurb_splitter),
title_prefix=title_prefix,
metadata_suffix=metadata_suffix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
chunks.extend(large_section_chunks)
continue
@@ -193,12 +224,13 @@ def chunk_document(
DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
blurb=extract_blurb(chunk_text, blurb_splitter),
content=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
)
link_offsets = {0: section_link_text}
@@ -211,12 +243,13 @@ def chunk_document(
DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
blurb=extract_blurb(chunk_text, blurb_splitter),
content=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
)
return chunks
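
The content token budget implied by the chunker above can be made concrete: the title prefix and the semantic metadata suffix are carved out of the chunk size, the semantic suffix is dropped when it exceeds 25% of the chunk, and if too little room remains for content the prefix/suffix are dropped entirely. A worked sketch of that arithmetic, using the constants from the diff (the token counts themselves are illustrative):

DOC_EMBEDDING_CONTEXT_SIZE = 512
MAX_METADATA_PERCENTAGE = 0.25
CHUNK_MIN_CONTENT = 256

def content_budget(title_tokens: int, metadata_tokens: int) -> tuple[int, bool]:
    """Return (tokens left for content, whether the prefix/suffix were kept)."""
    if metadata_tokens >= DOC_EMBEDDING_CONTEXT_SIZE * MAX_METADATA_PERCENTAGE:
        metadata_tokens = 0  # semantic metadata dropped; keyword metadata is kept regardless
    limit = DOC_EMBEDDING_CONTEXT_SIZE - title_tokens - metadata_tokens
    if limit <= CHUNK_MIN_CONTENT:
        # Not enough room for real content: drop the title prefix and semantic suffix.
        return DOC_EMBEDDING_CONTEXT_SIZE, False
    return limit, True

print(content_budget(title_tokens=20, metadata_tokens=60))    # (432, True)
print(content_budget(title_tokens=200, metadata_tokens=100))  # (512, False)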

View File

@@ -4,7 +4,6 @@ from abc import abstractmethod
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -14,8 +13,7 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.utils.batching import batch_list
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
@@ -66,48 +64,38 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
)
def embed_chunks(
self,
chunks: list[DocAwareChunk],
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[IndexChunk]:
# Cache the Title embeddings to only have to do it once
title_embed_dict: dict[str, list[float]] = {}
title_embed_dict: dict[str, list[float] | None] = {}
embedded_chunks: list[IndexChunk] = []
# Create Mini Chunks for more precise matching of details
# Off by default with unedited settings
chunk_texts = []
chunk_texts: list[str] = []
chunk_mini_chunks_count = {}
for chunk_ind, chunk in enumerate(chunks):
chunk_texts.append(chunk.content)
# The whole chunk including the prefix/suffix is included in the overall vector representation
chunk_texts.append(
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
)
mini_chunk_texts = (
split_chunk_text_into_mini_chunks(chunk.content_summary)
split_chunk_text_into_mini_chunks(chunk.content)
if enable_mini_chunk
else []
)
chunk_texts.extend(mini_chunk_texts)
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
# Batching for embedding
text_batches = batch_list(chunk_texts, batch_size)
embeddings: list[list[float]] = []
len_text_batches = len(text_batches)
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
# Normalize embeddings is only configured via model_configs.py, be sure to use right
# value for the set loss
embeddings.extend(
self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
)
# Replace line above with the line below for easy debugging of indexing flow
# skipping the actual model
# embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
embeddings = self.embedding_model.encode(
chunk_texts, text_type=EmbedTextType.PASSAGE
)
chunk_titles = {
chunk.source_document.get_title_for_document_index() for chunk in chunks
@@ -116,16 +104,15 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
# Drop any None or empty strings
chunk_titles_list = [title for title in chunk_titles if title]
# Embed Titles in batches
title_batches = batch_list(chunk_titles_list, batch_size)
len_title_batches = len(title_batches)
for ind_batch, title_batch in enumerate(title_batches, start=1):
logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
if chunk_titles_list:
title_embeddings = self.embedding_model.encode(
title_batch, text_type=EmbedTextType.PASSAGE
chunk_titles_list, text_type=EmbedTextType.PASSAGE
)
title_embed_dict.update(
{title: vector for title, vector in zip(title_batch, title_embeddings)}
{
title: vector
for title, vector in zip(chunk_titles_list, title_embeddings)
}
)
# Mapping embeddings to chunks
@@ -184,4 +171,6 @@ def get_embedding_model_from_db_embedding_model(
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
provider_type=db_embedding_model.provider_type,
api_key=db_embedding_model.api_key,
)
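
After this refactor, the text sent to the embedding model is the fully composed semantic form, while chunk.content itself stays clean for keyword highlighting and mini-chunks. A tiny sketch of that composition; ChunkStub is an illustrative stand-in for DocAwareChunk, and RETURN_SEPARATOR matches the constant in the diff:

from dataclasses import dataclass

RETURN_SEPARATOR = "\n\r\n"

@dataclass
class ChunkStub:  # illustrative stand-in for DocAwareChunk
    title_prefix: str
    content: str
    metadata_suffix_semantic: str

chunk = ChunkStub(
    title_prefix="Onboarding Guide" + RETURN_SEPARATOR,
    content="Step 1: request access to the VPN...",
    metadata_suffix_semantic=RETURN_SEPARATOR + "Metadata:\n\tauthor - John Doe",
)
# What gets embedded: prefix + content + semantic metadata suffix
embed_text = f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
# What gets stored for keyword highlighting (CONTENT_SUMMARY): just chunk.content
print(embed_text)
print(chunk.content)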

View File

@@ -124,6 +124,19 @@ def index_doc_batch(
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
# Skip documents that have neither title nor content
documents_to_process = []
for document in documents:
if not document.title and not any(
section.text.strip() for section in document.sections
):
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content"
)
else:
documents_to_process.append(document)
documents = documents_to_process
document_ids = [document.id for document in documents]
db_docs = get_documents_by_ids(
document_ids=document_ids,

View File

@@ -13,7 +13,7 @@ if TYPE_CHECKING:
logger = setup_logger()
Embedding = list[float]
Embedding = list[float] | None
class ChunkEmbedding(BaseModel):
@@ -36,15 +36,13 @@ class DocAwareChunk(BaseChunk):
# During inference we only have access to the document id and do not reconstruct the Document
source_document: Document
# The Vespa documents require a separate highlight field. Since it is stored as a duplicate anyway,
# it's easier to just store a not prefixed/suffixed string for the highlighting
# Also during the chunking, this non-prefixed/suffixed string is used for mini-chunks
content_summary: str
title_prefix: str
# During indexing we also (optionally) build a metadata string from the metadata dict
# This is also indexed so that we can strip it out after indexing, this way it supports
# multiple iterations of metadata representation for backwards compatibility
metadata_suffix: str
metadata_suffix_semantic: str
metadata_suffix_keyword: str
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a chunk"""

View File

@@ -34,8 +34,8 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import message_generator_to_string_generator
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
@@ -99,6 +99,7 @@ class Answer:
answer_style_config: AnswerStyleConfig,
llm: LLM,
prompt_config: PromptConfig,
force_use_tool: ForceUseTool,
# must be the same length as `docs`. If None, all docs are considered "relevant"
message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None,
@@ -107,10 +108,8 @@ class Answer:
latest_query_files: list[InMemoryChatFile] | None = None,
files: list[InMemoryChatFile] | None = None,
tools: list[Tool] | None = None,
# if specified, tells the LLM to always use this tool
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
# but we only support them anyways
force_use_tool: ForceUseTool | None = None,
# if set to True, then never use the LLM's provided tool-calling functionality
skip_explicit_tool_calling: bool = False,
# Returns the full document sections text from the search tool
@@ -129,6 +128,7 @@ class Answer:
self.tools = tools or []
self.force_use_tool = force_use_tool
self.skip_explicit_tool_calling = skip_explicit_tool_calling
self.message_history = message_history or []
@@ -187,7 +187,7 @@ class Answer:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool and self.force_use_tool.args is not None:
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
# / need to generate the args
tool_call_chunk = AIMessageChunk(
@@ -221,7 +221,7 @@ class Answer:
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool else None,
tool_choice="required" if self.force_use_tool.force_use else None,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
@@ -240,12 +240,26 @@ class Answer:
# if we have a tool call, we need to call the tool
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
tool = [
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
][0]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool and self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
@@ -286,7 +300,7 @@ class Answer:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
if self.force_use_tool:
if self.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
iter(
@@ -303,7 +317,7 @@ class Answer:
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,

View File

@@ -12,8 +12,8 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow

View File

@@ -14,8 +14,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection

View File

@@ -70,6 +70,7 @@ def load_llm_providers(db_session: Session) -> None:
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
),
model_names=model_names,
is_public=True,
)
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
update_default_provider(db_session, llm_provider.id)

View File

@@ -1,6 +1,5 @@
from collections.abc import Callable
from collections.abc import Iterator
from copy import copy
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -16,10 +15,8 @@ from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from tiktoken.core import Encoding
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
@@ -28,7 +25,6 @@ from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL
@@ -37,53 +33,6 @@ if TYPE_CHECKING:
logger = setup_logger()
_LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
def get_default_llm_tokenizer() -> Encoding:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
if _LLM_TOKENIZER is None:
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
return _LLM_TOKENIZER
def get_default_llm_token_encode() -> Callable[[str], Any]:
global _LLM_TOKENIZER_ENCODE
if _LLM_TOKENIZER_ENCODE is None:
tokenizer = get_default_llm_tokenizer()
if isinstance(tokenizer, Encoding):
return tokenizer.encode # type: ignore
# Currently only supports OpenAI encoder
raise ValueError("Invalid Encoder selected")
return _LLM_TOKENIZER_ENCODE
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: Encoding
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
def tokenizer_trim_chunks(
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
) -> list[InferenceChunk]:
tokenizer = get_default_llm_tokenizer()
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks
def translate_danswer_msg_to_langchain(
msg: Union[ChatMessage, "PreviousMessage"],

View File

@@ -50,8 +50,8 @@ from danswer.db.standard_answer import create_initial_default_standard_answer_ca
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.llm.llm_initialization import load_llm_providers
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router

View File

@@ -1,14 +1,13 @@
import gc
import os
import time
from typing import Optional
from typing import TYPE_CHECKING
import requests
from transformers import logging as transformer_logging # type:ignore
from httpx import HTTPError
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.natural_language_processing.utils import get_default_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
@@ -20,50 +19,13 @@ from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
logger = setup_logger()
if TYPE_CHECKING:
from transformers import AutoTokenizer # type: ignore
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
# NOTE: doing a local import here to avoid the memory usage caused by
# processes importing this file despite not using any of this
from transformers import AutoTokenizer # type: ignore
global _TOKENIZER
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
if _TOKENIZER[0] is not None:
del _TOKENIZER
gc.collect()
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
return _TOKENIZER[0]
def build_model_server_url(
model_server_host: str,
model_server_port: int,
@@ -91,6 +53,7 @@ class EmbeddingModel:
provider_type: str | None,
# The following globals are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
retrim_content: bool = False,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -99,32 +62,90 @@ class EmbeddingModel:
self.passage_prefix = passage_prefix
self.normalize = normalize
self.model_name = model_name
self.retrim_content = retrim_content
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
if text_type == EmbedTextType.QUERY and self.query_prefix:
prefixed_texts = [self.query_prefix + text for text in texts]
elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
prefixed_texts = [self.passage_prefix + text for text in texts]
else:
prefixed_texts = texts
def encode(
self,
texts: list[str],
text_type: EmbedTextType,
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
) -> list[list[float] | None]:
if not texts:
logger.warning("No texts to be embedded")
return []
embed_request = EmbedRequest(
model_name=self.model_name,
texts=prefixed_texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
)
if self.retrim_content:
# This is applied during indexing as a catchall for overly long titles (or other uncapped fields)
# Note that this uses just the default tokenizer which may also lead to very minor miscountings
# However this slight miscounting is very unlikely to have any material impact.
texts = [
tokenizer_trim_content(
content=text,
desired_length=self.max_seq_length,
tokenizer=get_default_tokenizer(),
)
for text in texts
]
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
response.raise_for_status()
if self.provider_type:
embed_request = EmbedRequest(
model_name=self.model_name,
texts=texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
EmbedResponse(**response.json()).embeddings
return EmbedResponse(**response.json()).embeddings
return EmbedResponse(**response.json()).embeddings
# Batching for local embedding
text_batches = batch_list(texts, batch_size)
embeddings: list[list[float] | None] = []
for idx, text_batch in enumerate(text_batches, start=1):
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
# Normalize embeddings is only configured via model_configs.py, be sure to use the right
# value for the set loss
embeddings.extend(EmbedResponse(**response.json()).embeddings)
return embeddings
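
For reference, a minimal sketch of how the reworked `EmbeddingModel.encode` might be called after this change. The keyword names mirror the fields visible in this diff; the model name and the server host/port are placeholder assumptions, not values from the repository:
```
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from shared_configs.enums import EmbedTextType

# All concrete values below are illustrative assumptions
model = EmbeddingModel(
    model_name="intfloat/e5-base-v2",   # assumed local model name
    query_prefix="query: ",
    passage_prefix="passage: ",
    normalize=True,
    api_key=None,                       # no cloud provider in this sketch
    provider_type=None,
    server_host="localhost",            # placeholder model server host
    server_port=9000,                   # placeholder model server port
    retrim_content=True,                # catchall trim for overly long fields
)

# Local models are now embedded in batches; empty inputs come back as None entries
embeddings = model.encode(
    texts=["first passage", "second passage"],
    text_type=EmbedTextType.PASSAGE,
)
```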
class CrossEncoderEnsembleModel:
@@ -136,7 +157,7 @@ class CrossEncoderEnsembleModel:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
def predict(self, query: str, passages: list[str]) -> list[list[float]]:
def predict(self, query: str, passages: list[str]) -> list[list[float] | None]:
rerank_request = RerankRequest(query=query, documents=passages)
response = requests.post(
@@ -199,7 +220,7 @@ def warm_up_encoders(
# First time downloading the models it may take even longer, but just in case,
# retry the whole server
wait_time = 5
for attempt in range(20):
for _ in range(20):
try:
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
return

View File

@@ -0,0 +1,100 @@
import gc
import os
from collections.abc import Callable
from copy import copy
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
import tiktoken
from tiktoken.core import Encoding
from transformers import logging as transformer_logging # type:ignore
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
if TYPE_CHECKING:
from transformers import AutoTokenizer # type: ignore
logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
_LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
# NOTE: doing a local import here to avoid the memory usage caused by
# processes importing this file despite not using any of this
from transformers import AutoTokenizer # type: ignore
global _TOKENIZER
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
if _TOKENIZER[0] is not None:
del _TOKENIZER
gc.collect()
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
return _TOKENIZER[0]
def get_default_llm_tokenizer() -> Encoding:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
if _LLM_TOKENIZER is None:
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
return _LLM_TOKENIZER
def get_default_llm_token_encode() -> Callable[[str], Any]:
global _LLM_TOKENIZER_ENCODE
if _LLM_TOKENIZER_ENCODE is None:
tokenizer = get_default_llm_tokenizer()
if isinstance(tokenizer, Encoding):
return tokenizer.encode # type: ignore
# Currently only supports OpenAI encoder
raise ValueError("Invalid Encoder selected")
return _LLM_TOKENIZER_ENCODE
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: Encoding
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
def tokenizer_trim_chunks(
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
) -> list[InferenceChunk]:
tokenizer = get_default_llm_tokenizer()
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks
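
A short usage sketch of the relocated tokenizer helpers, assuming the module path shown in this new file; the token budget is an arbitrary illustrative value:
```
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content

tokenizer = get_default_llm_tokenizer()  # tiktoken cl100k_base encoding

# Trim a long string down to an assumed 512-token budget
trimmed = tokenizer_trim_content(
    content="some very long passage " * 200,
    desired_length=512,
    tokenizer=tokenizer,
)
```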

View File

@@ -34,7 +34,7 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
@@ -206,6 +206,7 @@ def stream_answer_objects(
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
force_use=True,
tool_name=search_tool.name,
args={"query": rephrased_query},
),
@@ -256,6 +257,9 @@ def stream_answer_objects(
)
yield initial_response
elif packet.id == SEARCH_DOC_CONTENT_ID:
yield packet.response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
@@ -268,9 +272,6 @@ def stream_answer_objects(
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
elif packet.id == SEARCH_DOC_CONTENT_ID:
yield packet.response
elif packet.id == SEARCH_EVALUATION_ID:
evaluation_response = LLMRelevanceSummaryResponse(
relevance_summaries=packet.response

View File

@@ -2,7 +2,7 @@ from collections.abc import Callable
from collections.abc import Generator
from danswer.configs.constants import MessageType
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import ThreadMessage
from danswer.utils.logger import setup_logger

View File

@@ -34,6 +34,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time
logger = setup_logger()
@@ -154,6 +155,7 @@ class SearchPipeline:
return cast(list[InferenceChunk], self._retrieved_chunks)
@log_function_time(print_only=True)
def _get_sections(self) -> list[InferenceSection]:
"""Returns an expanded section from each of the chunks.
If whole docs (instead of above/below context) is specified then it will give back all of the whole docs
@@ -173,9 +175,11 @@ class SearchPipeline:
expanded_inference_sections = []
# Full doc setting takes priority
if self.search_query.full_doc:
seen_document_ids = set()
unique_chunks = []
# This preserves the ordering since the chunks are retrieved in score order
for chunk in retrieved_chunks:
if chunk.document_id not in seen_document_ids:
@@ -195,7 +199,6 @@ class SearchPipeline:
),
)
)
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
@@ -240,32 +243,35 @@ class SearchPipeline:
merged_ranges = [
merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
]
flat_ranges = [r for ranges in merged_ranges for r in ranges]
flat_ranges: list[ChunkRange] = [r for ranges in merged_ranges for r in ranges]
flattened_inference_chunks: list[InferenceChunk] = []
parallel_functions_with_args = []
for chunk_range in flat_ranges:
functions_with_args.append(
(
# If Large Chunks are introduced, additional filters need to be added here
self.document_index.id_based_retrieval,
(
# Only need the document_id here, just use any chunk in the range is fine
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
# There is no chunk level permissioning, this expansion around chunks
# can be assumed to be safe
IndexFilters(access_control_list=None),
),
)
)
# Don't need to fetch chunks within range for merging if chunk_above / below are 0.
if above == below == 0:
flattened_inference_chunks.extend(chunk_range.chunks)
# list of list of inference chunks where the inner list needs to be combined for content
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
flattened_inference_chunks = [
chunk for sublist in list_inference_chunks for chunk in sublist
]
else:
parallel_functions_with_args.append(
(
self.document_index.id_based_retrieval,
(
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
IndexFilters(access_control_list=None),
),
)
)
if parallel_functions_with_args:
list_inference_chunks = run_functions_tuples_in_parallel(
parallel_functions_with_args, allow_failures=False
)
for inference_chunks in list_inference_chunks:
flattened_inference_chunks.extend(inference_chunks)
doc_chunk_ind_to_chunk = {
(chunk.document_id, chunk.chunk_id): chunk

View File

@@ -4,7 +4,7 @@ from typing import cast
import numpy
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
@@ -12,6 +12,9 @@ from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import (
CrossEncoderEnsembleModel,
)
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
@@ -20,7 +23,6 @@ from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
@@ -58,8 +60,14 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
if chunk.content.startswith(chunk.title):
return chunk.content[len(chunk.title) :].lstrip()
if chunk.content.startswith(chunk.title[:MAX_CHUNK_TITLE_LEN]):
return chunk.content[MAX_CHUNK_TITLE_LEN:].lstrip()
# BLURB SIZE is by token instead of char but each token is at least 1 char
# If this prefix matches the content, it's assumed the title was prepended
if chunk.content.startswith(chunk.title[:BLURB_SIZE]):
return (
chunk.content.split(RETURN_SEPARATOR, 1)[-1]
if RETURN_SEPARATOR in chunk.content
else chunk.content
)
return chunk.content

View File

@@ -1,10 +1,10 @@
from typing import TYPE_CHECKING
from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
from danswer.natural_language_processing.search_nlp_models import IntentModel
from danswer.search.enums import QueryFlow
from danswer.search.models import SearchType
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.search.search_nlp_models import IntentModel
from danswer.server.query_and_chat.models import HelperResponse
from danswer.utils.logger import setup_logger

View File

@@ -1,5 +1,6 @@
import string
from collections.abc import Callable
from typing import cast
import nltk # type:ignore
from nltk.corpus import stopwords # type:ignore
@@ -11,6 +12,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
@@ -20,7 +22,6 @@ from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
@@ -143,7 +144,9 @@ def doc_index_retrieval(
if query.search_type == SearchType.SEMANTIC:
top_chunks = document_index.semantic_retrieval(
query=query.query,
query_embedding=query_embedding,
query_embedding=cast(
list[float], query_embedding
), # query embeddings should always have vector representations
filters=query.filters,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,
@@ -152,7 +155,9 @@ def doc_index_retrieval(
elif query.search_type == SearchType.HYBRID:
top_chunks = document_index.hybrid_retrieval(
query=query.query,
query_embedding=query_embedding,
query_embedding=cast(
list[float], query_embedding
), # query embeddings should always have vector representations
filters=query.filters,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,

View File

@@ -94,7 +94,7 @@ def history_based_query_rephrase(
llm: LLM,
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
skip_first_rephrase: bool = False,
skip_first_rephrase: bool = True,
prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> str:
# Globally disabled, just use the exact user query

View File

@@ -96,6 +96,8 @@ def upsert_ingestion_doc(
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
)
indexing_pipeline = build_indexing_pipeline(
@@ -132,6 +134,8 @@ def upsert_ingestion_doc(
normalize=sec_db_embedding_model.normalize,
query_prefix=sec_db_embedding_model.query_prefix,
passage_prefix=sec_db_embedding_model.passage_prefix,
api_key=sec_db_embedding_model.api_key,
provider_type=sec_db_embedding_model.provider_type,
)
sec_ind_pipeline = build_indexing_pipeline(

View File

@@ -9,7 +9,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.server.documents.models import ChunkInfo

View File

@@ -10,7 +10,7 @@ from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import remove_embedding_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.models import User
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.embedding.models import TestEmbeddingRequest
@@ -42,7 +42,7 @@ def test_embedding_configuration(
passage_prefix=None,
model_name=None,
)
test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)
except ValueError as e:
error_msg = f"Not a valid embedding model. Exception thrown: {e}"

View File

@@ -147,10 +147,10 @@ def set_provider_as_default(
@basic_router.get("/provider")
def list_llm_provider_basics(
_: User | None = Depends(current_user),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
return [
LLMProviderDescriptor.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
]

View File

@@ -60,6 +60,8 @@ class LLMProvider(BaseModel):
custom_config: dict[str, str] | None
default_model_name: str
fast_default_model_name: str | None
is_public: bool = True
groups: list[int] | None = None
class LLMProviderUpsertRequest(LLMProvider):
@@ -91,4 +93,6 @@ class FullLLMProvider(LLMProvider):
or fetch_models_for_provider(llm_provider_model.provider)
or [llm_provider_model.default_model_name]
),
is_public=llm_provider_model.is_public,
groups=[group.id for group in llm_provider_model.groups],
)

View File

@@ -1,3 +1,4 @@
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
@@ -14,13 +15,15 @@ from danswer.db.models import SlackBotConfig as SlackBotConfigModel
from danswer.db.models import SlackBotResponseType
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
from danswer.db.models import User
from danswer.indexing.models import EmbeddingModelDetail
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot
if TYPE_CHECKING:
from danswer.db.models import User as UserModel
pass
class VersionResponse(BaseModel):
@@ -46,9 +49,17 @@ class UserInfo(BaseModel):
is_verified: bool
role: UserRole
preferences: UserPreferences
oidc_expiry: datetime | None = None
current_token_created_at: datetime | None = None
current_token_expiry_length: int | None = None
@classmethod
def from_model(cls, user: "UserModel") -> "UserInfo":
def from_model(
cls,
user: User,
current_token_created_at: datetime | None = None,
expiry_length: int | None = None,
) -> "UserInfo":
return cls(
id=str(user.id),
email=user.email,
@@ -57,6 +68,9 @@ class UserInfo(BaseModel):
is_verified=user.is_verified,
role=user.role,
preferences=(UserPreferences(chosen_assistants=user.chosen_assistants)),
oidc_expiry=user.oidc_expiry,
current_token_created_at=current_token_created_at,
current_token_expiry_length=expiry_length,
)
@@ -151,7 +165,9 @@ class SlackBotConfigCreationRequest(BaseModel):
# by an optional `PersonaSnapshot` object. Keeping it like this
# for now for simplicity / speed of development
document_sets: list[int] | None
persona_id: int | None # NOTE: only one of `document_sets` / `persona_id` should be set
persona_id: (
int | None
) # NOTE: only one of `document_sets` / `persona_id` should be set
channel_names: list[str]
respond_tag_only: bool = False
respond_to_bots: bool = False

View File

@@ -1,4 +1,5 @@
import re
from datetime import datetime
from fastapi import APIRouter
from fastapi import Body
@@ -6,6 +7,9 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import status
from pydantic import BaseModel
from sqlalchemy import Column
from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
@@ -19,9 +23,11 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.engine import get_session
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.db.users import list_users
@@ -117,9 +123,9 @@ def list_all_users(
id=user.id,
email=user.email,
role=user.role,
status=UserStatus.LIVE
if user.is_active
else UserStatus.DEACTIVATED,
status=(
UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
),
)
for user in users
],
@@ -246,9 +252,35 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
return UserRoleResponse(role=user.role)
def get_current_token_creation(
user: User | None, db_session: Session
) -> datetime | None:
if user is None:
return None
try:
result = db_session.execute(
select(AccessToken)
.where(AccessToken.user_id == user.id) # type: ignore
.order_by(desc(Column("created_at")))
.limit(1)
)
access_token = result.scalar_one_or_none()
if access_token:
return access_token.created_at
else:
logger.error("No AccessToken found for user")
return None
except Exception as e:
logger.error(f"Error fetching AccessToken: {e}")
return None
@router.get("/me")
def verify_user_logged_in(
user: User | None = Depends(optional_user),
db_session: Session = Depends(get_session),
) -> UserInfo:
# NOTE: this does not use `current_user` / `current_admin_user` because we don't want
# to enforce user verification here - the frontend always wants to get the info about
@@ -264,7 +296,14 @@ def verify_user_logged_in(
status_code=status.HTTP_403_FORBIDDEN, detail="User Not Authenticated"
)
return UserInfo.from_model(user)
token_created_at = get_current_token_creation(user, db_session)
user_info = UserInfo.from_model(
user,
current_token_created_at=token_created_at,
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
)
return user_info
"""APIs to adjust user preferences"""

View File

@@ -45,7 +45,7 @@ from danswer.llm.answering.prompts.citations_prompt import (
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.headers import get_litellm_additional_request_headers
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
)

View File

@@ -90,7 +90,7 @@ class CreateChatMessageRequest(ChunkContext):
parent_message_id: int | None
# New message contents
message: str
# file's that we should attach to this message
# Files that we should attach to this message
file_descriptors: list[FileDescriptor]
# If no prompt provided, uses the largest prompt of the chat session
# but really this should be explicitly specified, only in the simplified APIs is this inferred

View File

@@ -1,13 +1,15 @@
from typing import Any
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
from danswer.tools.tool import Tool
class ForceUseTool(BaseModel):
# May not be a forced usage of the tool but can still have args, in which case,
# if the tool is called, those args are applied instead of what the LLM
# wanted to call it with
force_use: bool
tool_name: str
args: dict[str, Any] | None = None
@@ -16,25 +18,10 @@ class ForceUseTool(BaseModel):
return {"type": "function", "function": {"name": self.tool_name}}
def modify_message_chain_for_force_use_tool(
messages: list[BaseMessage], force_use_tool: ForceUseTool | None = None
) -> list[BaseMessage]:
"""NOTE: modifies `messages` in place."""
if not force_use_tool:
return messages
for message in messages:
if isinstance(message, AIMessage) and message.tool_calls:
for tool_call in message.tool_calls:
tool_call["args"] = force_use_tool.args or {}
return messages
def filter_tools_for_force_tool_use(
tools: list[Tool], force_use_tool: ForceUseTool | None = None
tools: list[Tool], force_use_tool: ForceUseTool
) -> list[Tool]:
if not force_use_tool:
if not force_use_tool.force_use:
return tools
return [tool for tool in tools if tool.name == force_use_tool.tool_name]
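
With `force_use` now an explicit field, callers construct `ForceUseTool` unconditionally rather than passing `None`. A hedged sketch mirroring the `stream_answer_objects` usage elsewhere in this diff; the import path, tool name, and query are placeholder assumptions:
```
# Import path assumed; ForceUseTool is the model defined above
from danswer.tools.force import ForceUseTool

forced = ForceUseTool(
    force_use=True,
    tool_name="run_search",              # placeholder tool name
    args={"query": "rephrased user query"},
)

# With force_use=False, the args (if provided) still override whatever arguments
# the LLM supplies if it decides to call that tool on its own.
not_forced = ForceUseTool(force_use=False, tool_name="run_search")
```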

View File

@@ -6,7 +6,7 @@ from langchain_core.messages.tool import ToolCall
from langchain_core.messages.tool import ToolMessage
from pydantic import BaseModel
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
def build_tool_message(

View File

@@ -2,7 +2,7 @@ import json
from tiktoken import Encoding
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.tools.tool import Tool

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
@@ -194,6 +195,15 @@ def _cleanup_user__user_group_relationships__no_commit(
db_session.delete(user__user_group_relationship)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session: Session, user_group_id: int
) -> None:
@@ -316,6 +326,9 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)

View File

@@ -10,6 +10,7 @@ from cohere import Client as CohereClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from vertexai.language_models import TextEmbeddingInput # type: ignore
@@ -40,110 +41,133 @@ router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
# If we are not only indexing, we don't want to retry for very long
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
def _initialize_client(
api_key: str, provider: EmbeddingProvider, model: str | None = None
) -> Any:
if provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=api_key)
elif provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=api_key)
elif provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=api_key)
elif provider == EmbeddingProvider.GOOGLE:
credentials = service_account.Credentials.from_service_account_info(
json.loads(api_key)
)
project_id = json.loads(api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
else:
raise ValueError(f"Unsupported provider: {provider}")
class CloudEmbedding:
def __init__(self, api_key: str, provider: str, model: str | None = None):
self.api_key = api_key
def __init__(
self,
api_key: str,
provider: str,
# Only for Google as is needed on client setup
self.model = model
model: str | None = None,
) -> None:
try:
self.provider = EmbeddingProvider(provider.lower())
except ValueError:
raise ValueError(f"Unsupported provider: {provider}")
self.client = self._initialize_client()
self.client = _initialize_client(api_key, self.provider, model)
def _initialize_client(self) -> Any:
if self.provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=self.api_key)
elif self.provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=self.api_key)
elif self.provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=self.api_key)
elif self.provider == EmbeddingProvider.GOOGLE:
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.api_key)
)
project_id = json.loads(self.api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(
self.model or DEFAULT_VERTEX_MODEL
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def encode(
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
) -> list[list[float]]:
return [
self.embed(text=text, text_type=text_type, model=model_name)
for text in texts
]
def embed(
self, *, text: str, text_type: EmbedTextType, model: str | None = None
) -> list[float]:
logger.debug(f"Embedding text with provider: {self.provider}")
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(text, model)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(text, model, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(text, model, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(text, model, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def _embed_openai(self, text: str, model: str | None) -> list[float]:
def _embed_openai(
self, texts: list[str], model: str | None
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_OPENAI_MODEL
response = self.client.embeddings.create(input=text, model=model)
return response.data[0].embedding
# OpenAI does not seem to provide a truncation option, however
# the context lengths used by Danswer currently are smaller than the max token length
# for OpenAI embeddings so it's not a big deal
response = self.client.embeddings.create(input=texts, model=model)
return [embedding.embedding for embedding in response.data]
def _embed_cohere(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_COHERE_MODEL
# Does not use the same tokenizer as the Danswer API server but it's approximately the same
# empirically it's only off by a very few tokens so it's not a big deal
response = self.client.embed(
texts=[text],
texts=texts,
model=model,
input_type=embedding_type,
truncate="END",
)
return response.embeddings[0]
return response.embeddings
def _embed_voyage(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_VOYAGE_MODEL
response = self.client.embed(text, model=model, input_type=embedding_type)
return response.embeddings[0]
# Similar to Cohere, the API server will do approximate size chunking
# it's acceptable to miss by a few tokens
response = self.client.embed(
texts,
model=model,
input_type=embedding_type,
truncation=True, # Also this is default
)
return response.embeddings
def _embed_vertex(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
self, texts: list[str], model: str | None, embedding_type: str
) -> list[list[float] | None]:
if model is None:
model = DEFAULT_VERTEX_MODEL
embedding = self.client.get_embeddings(
embeddings = self.client.get_embeddings(
[
TextEmbeddingInput(
text,
embedding_type,
)
]
for text in texts
],
auto_truncate=True, # Also this is default
)
return embedding[0].values
return [embedding.values for embedding in embeddings]
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
def embed(
self,
*,
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
) -> list[list[float] | None]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error embedding text with {self.provider}: {str(e)}",
)
@staticmethod
def create(
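
A minimal sketch of calling the refactored `CloudEmbedding`, which now embeds whole batches and retries on failure. The API key and texts are placeholders; `model_name=None` falls back to the provider default as shown above:
```
# Illustrative values only
cloud_model = CloudEmbedding(
    api_key="sk-placeholder",
    provider="openai",
    model=None,   # only needed up front for the Google/Vertex client
)
vectors = cloud_model.embed(
    texts=["first passage", "second passage"],
    text_type=EmbedTextType.PASSAGE,
    model_name=None,   # falls back to DEFAULT_OPENAI_MODEL
)
```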
@@ -166,7 +190,7 @@ def get_embedding_model(
if model_name not in _GLOBAL_MODELS_DICT:
logger.info(f"Loading {model_name}")
model = SentenceTransformer(model_name)
model = SentenceTransformer(model_name, trust_remote_code=True)
model.max_seq_length = max_context_length
_GLOBAL_MODELS_DICT[model_name] = model
elif max_context_length != _GLOBAL_MODELS_DICT[model_name].max_seq_length:
@@ -212,34 +236,83 @@ def embed_text(
normalize_embeddings: bool,
api_key: str | None,
provider_type: str | None,
) -> list[list[float]]:
if provider_type is not None:
prefix: str | None,
) -> list[list[float] | None]:
non_empty_texts = []
empty_indices = []
for idx, text in enumerate(texts):
if text.strip():
non_empty_texts.append(text)
else:
empty_indices.append(idx)
# Third party API based embedding model
if not non_empty_texts:
embeddings = []
elif provider_type is not None:
logger.debug(f"Embedding text with provider: {provider_type}")
if api_key is None:
raise RuntimeError("API key not provided for cloud model")
if prefix:
# This may change in the future if some providers require the user
# to manually append a prefix but this is not the case currently
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
)
cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.encode(texts, model_name, text_type)
embeddings = cloud_model.embed(
texts=non_empty_texts,
model_name=model_name,
text_type=text_type,
)
elif model_name is not None:
hosted_model = get_embedding_model(
prefixed_texts = (
[f"{prefix}{text}" for text in non_empty_texts]
if prefix
else non_empty_texts
)
local_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = hosted_model.encode(
texts, normalize_embeddings=normalize_embeddings
embeddings = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
)
else:
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
if embeddings is None:
raise RuntimeError("Embeddings were not created")
raise RuntimeError("Failed to create Embeddings")
if not isinstance(embeddings, list):
embeddings = embeddings.tolist()
embeddings_with_nulls: list[list[float] | None] = []
current_embedding_index = 0
for idx in range(len(texts)):
if idx in empty_indices:
embeddings_with_nulls.append(None)
else:
embedding = embeddings[current_embedding_index]
if isinstance(embedding, list) or embedding is None:
embeddings_with_nulls.append(embedding)
else:
embeddings_with_nulls.append(embedding.tolist())
current_embedding_index += 1
embeddings = embeddings_with_nulls
return embeddings
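
The reworked `embed_text` now skips empty or whitespace-only strings and re-inserts `None` at their original positions, so the output length always matches the input. An illustrative call, not a test from the repository; the model name is an assumption:
```
# Empty/whitespace-only inputs come back as None in their original slots
texts = ["first passage", "   ", "second passage"]
embeddings = embed_text(
    texts=texts,
    model_name="intfloat/e5-base-v2",   # assumed local model
    max_context_length=512,
    normalize_embeddings=True,
    api_key=None,
    provider_type=None,
    text_type=EmbedTextType.PASSAGE,
    prefix="passage: ",
)
assert len(embeddings) == len(texts)
assert embeddings[1] is None
```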
@simple_log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float] | None]:
cross_encoders = get_local_reranking_model_ensemble()
sim_scores = [
encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
@@ -252,7 +325,17 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
async def process_embed_request(
embed_request: EmbedRequest,
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
try:
if embed_request.text_type == EmbedTextType.QUERY:
prefix = embed_request.manual_query_prefix
elif embed_request.text_type == EmbedTextType.PASSAGE:
prefix = embed_request.manual_passage_prefix
else:
prefix = None
embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
@@ -261,13 +344,13 @@ async def process_embed_request(
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)
except Exception as e:
logger.exception(f"Error during embedding process:\n{str(e)}")
raise HTTPException(
status_code=500, detail="Failed to run Bi-Encoder embedding"
)
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
raise HTTPException(status_code=500, detail=exception_detail)
@router.post("/cross-encoder-scores")
@@ -276,6 +359,11 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
if not embed_request.documents or not embed_request.query:
raise HTTPException(
status_code=400, detail="No documents or query to be reranked"
)
try:
sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents

View File

@@ -1,6 +1,8 @@
einops==0.8.0
fastapi==0.109.2
h5py==3.9.0
pydantic==1.10.13
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
tensorflow==2.15.0
@@ -9,5 +11,5 @@ transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
openai==1.14.3
cohere==5.5.8
google-cloud-aiplatform==1.58.0
cohere==5.6.1
google-cloud-aiplatform==1.58.0

View File

@@ -21,18 +21,17 @@ def run_jobs(exclude_indexing: bool) -> None:
cmd_worker = [
"celery",
"-A",
"ee.danswer.background.celery",
"ee.danswer.background.celery.celery_app",
"worker",
"--pool=threads",
"--autoscale=3,10",
"--concurrency=16",
"--loglevel=INFO",
"--concurrency=1",
]
cmd_beat = [
"celery",
"-A",
"ee.danswer.background.celery",
"ee.danswer.background.celery.celery_app",
"beat",
"--loglevel=INFO",
]
@@ -74,7 +73,7 @@ def run_jobs(exclude_indexing: bool) -> None:
try:
update_env = os.environ.copy()
update_env["PYTHONPATH"] = "."
cmd_perm_sync = ["python", "ee.danswer/background/permission_sync.py"]
cmd_perm_sync = ["python", "ee/danswer/background/permission_sync.py"]
indexing_process = subprocess.Popen(
cmd_perm_sync,

View File

@@ -14,7 +14,10 @@ sys.path.append(parent_dir)
# flake8: noqa: E402
# Now import Danswer modules
from danswer.db.models import DocumentSet__ConnectorCredentialPair
from danswer.db.models import (
DocumentSet__ConnectorCredentialPair,
UserGroup__ConnectorCredentialPair,
)
from danswer.db.connector import fetch_connector_by_id
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.index_attempt import (
@@ -44,7 +47,7 @@ logger = setup_logger()
_DELETION_BATCH_SIZE = 1000
def unsafe_deletion(
def _unsafe_deletion(
db_session: Session,
document_index: DocumentIndex,
cc_pair: ConnectorCredentialPair,
@@ -82,11 +85,22 @@ def unsafe_deletion(
credential_id=credential_id,
)
# Delete document sets + connector / credential Pairs
# Delete document sets
stmt = delete(DocumentSet__ConnectorCredentialPair).where(
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id == pair_id
)
db_session.execute(stmt)
# delete user group associations
stmt = delete(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.cc_pair_id == pair_id
)
db_session.execute(stmt)
# need to flush to avoid foreign key violations
db_session.flush()
# delete the actual connector credential pair
stmt = delete(ConnectorCredentialPair).where(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
@@ -168,7 +182,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
files_deleted_count = unsafe_deletion(
files_deleted_count = _unsafe_deletion(
db_session=db_session,
document_index=document_index,
cc_pair=cc_pair,

View File

@@ -4,9 +4,7 @@ from shared_configs.enums import EmbedTextType
class EmbedRequest(BaseModel):
# This already includes any prefixes, the text is just passed directly to the model
texts: list[str]
# Can be none for cloud embedding model requests, error handling logic exists for other cases
model_name: str | None
max_context_length: int
@@ -14,10 +12,12 @@ class EmbedRequest(BaseModel):
api_key: str | None
provider_type: str | None
text_type: EmbedTextType
manual_query_prefix: str | None
manual_passage_prefix: str | None
class EmbedResponse(BaseModel):
embeddings: list[list[float]]
embeddings: list[list[float] | None]
class RerankRequest(BaseModel):
@@ -26,7 +26,7 @@ class RerankRequest(BaseModel):
class RerankResponse(BaseModel):
scores: list[list[float]]
scores: list[list[float] | None]
class IntentRequest(BaseModel):

View File

@@ -25,7 +25,7 @@ autorestart=true
# relatively compute-light (e.g. they tend to just make a bunch of requests to
# Vespa / Postgres)
[program:celery_worker]
command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads --autoscale=3,10 --loglevel=INFO --logfile=/var/log/celery_worker.log
command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads --concurrency=16 --loglevel=INFO --logfile=/var/log/celery_worker.log
stdout_logfile=/var/log/celery_worker_supervisor.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true

View File

@@ -9,7 +9,7 @@ This Python script automates the process of running search quality tests for a b
- Manages environment variables
- Switches to specified Git branch
- Uploads test documents
- Runs search quality tests using Relari
- Runs search quality tests
- Cleans up Docker containers (optional)
## Usage
@@ -29,9 +29,17 @@ export PYTHONPATH=$PYTHONPATH:$PWD/backend
```
cd backend/tests/regression/answer_quality
```
7. Run the script:
7. To launch the evaluation environment, run the launch_eval_env.py script (this step can be skipped if you are running the environment outside of Docker; just leave "environment_name" blank):
```
python run_eval_pipeline.py
python launch_eval_env.py
```
8. Run the file_uploader.py script to upload the zip files located at the path specified by "zipped_documents_file":
```
python file_uploader.py
```
9. Run the run_qa.py script to ask the questions from the jsonl file located at the path specified by "questions_file". This will hit the "query/answer-with-quote" API endpoint.
```
python run_qa.py
```
Note: All data will be saved even after the containers are shut down. There are instructions below for re-launching Docker containers using this data.
@@ -61,6 +69,11 @@ Edit `search_test_config.yaml` to set:
- Set this to true to automatically delete all docker containers, networks and volumes after the test
- launch_web_ui
- Set this to true if you want to use the UI during/after the testing process
- only_state
- Whether to only run Vespa and Postgres
- only_retrieve_docs
- Set to true to only retrieve documents, without generating an LLM response
- This is to save on API costs
- use_cloud_gpu
- Set to true or false depending on whether you want to use the remote GPU
- Only need to set this if use_cloud_gpu is true
@@ -70,12 +83,10 @@ Edit `search_test_config.yaml` to set:
- model_server_port
- This is the port of the remote model server
- Only need to set this if use_cloud_gpu is true
- existing_test_suffix (THIS IS NOT A SUFFIX ANYMORE, TODO UPDATE THE DOCS HERE)
- environment_name
- Use this if you would like to relaunch a previous test instance
- Input the suffix of the test you'd like to re-launch
- (E.g. to use the data from folder "test-1234-5678" put "-1234-5678")
- No new files will automatically be uploaded
- Leave empty to run a new test
- Input the env_name of the test you'd like to re-launch
- Leave empty to launch referencing local default network locations
- limit
- Max number of questions you'd like to ask against the dataset
- Set to null for no limit
@@ -85,7 +96,7 @@ Edit `search_test_config.yaml` to set:
## Relaunching From Existing Data
To launch an existing set of containers that has already completed indexing, set the existing_test_suffix variable. This will launch the docker containers mounted on the volumes of the indicated suffix and will not automatically index any documents or run any QA.
To launch an existing set of containers that has already completed indexing, set the environment_name variable. This will launch the docker containers mounted on the volumes of the indicated env_name and will not automatically index any documents or run any QA.
Once these containers are launched you can run file_uploader.py or run_qa.py (assuming you have run the steps in the Usage section above).
- file_uploader.py will upload and index additional zipped files located at the zipped_documents_file path.

View File

@@ -2,45 +2,31 @@ import requests
from retry import retry
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.connectors.models import InputType
from danswer.db.enums import IndexingStatus
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.models import IndexFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.server.documents.models import ConnectorBase
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
from tests.regression.answer_quality.cli_utils import get_api_server_host_port
GENERAL_HEADERS = {"Content-Type": "application/json"}
def _api_url_builder(run_suffix: str, api_path: str) -> str:
return f"http://localhost:{get_api_server_host_port(run_suffix)}" + api_path
def _create_new_chat_session(run_suffix: str) -> int:
create_chat_request = ChatSessionCreationRequest(
persona_id=0,
description=None,
)
body = create_chat_request.dict()
create_chat_url = _api_url_builder(run_suffix, "/chat/create-chat-session/")
response_json = requests.post(
create_chat_url, headers=GENERAL_HEADERS, json=body
).json()
chat_session_id = response_json.get("chat_session_id")
if isinstance(chat_session_id, int):
return chat_session_id
def _api_url_builder(env_name: str, api_path: str) -> str:
if env_name:
return f"http://localhost:{get_api_server_host_port(env_name)}" + api_path
else:
raise RuntimeError(response_json)
return "http://localhost:8080" + api_path
@retry(tries=10, delay=10)
def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
@retry(tries=5, delay=5)
def get_answer_from_query(
query: str, only_retrieve_docs: bool, env_name: str
) -> tuple[list[str], str]:
filters = IndexFilters(
source_type=None,
document_set=None,
@@ -48,42 +34,47 @@ def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
tags=None,
access_control_list=None,
)
retrieval_options = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
filters=filters,
enable_auto_detect_filters=False,
messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]
new_message_request = DirectQARequest(
messages=messages,
prompt_id=0,
persona_id=0,
retrieval_options=RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
filters=filters,
enable_auto_detect_filters=False,
),
chain_of_thought=False,
return_contexts=True,
skip_gen_ai_answer_generation=only_retrieve_docs,
)
chat_session_id = _create_new_chat_session(run_suffix)
url = _api_url_builder(run_suffix, "/chat/send-message-simple-api/")
new_message_request = BasicCreateChatMessageRequest(
chat_session_id=chat_session_id,
message=query,
retrieval_options=retrieval_options,
query_override=query,
)
url = _api_url_builder(env_name, "/query/answer-with-quote/")
headers = {
"Content-Type": "application/json",
}
body = new_message_request.dict()
body["user"] = None
try:
response_json = requests.post(url, headers=GENERAL_HEADERS, json=body).json()
simple_search_docs = response_json.get("simple_search_docs", [])
answer = response_json.get("answer", "")
response_json = requests.post(url, headers=headers, json=body).json()
context_data_list = response_json.get("contexts", {}).get("contexts", [])
answer = response_json.get("answer", "") or ""
except Exception as e:
print("Failed to answer the questions:")
print(f"\t {str(e)}")
print("trying again")
print("Try restarting vespa container and trying agian")
raise e
return simple_search_docs, answer
return context_data_list, answer
@retry(tries=10, delay=10)
def check_if_query_ready(run_suffix: str) -> bool:
url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/")
def check_indexing_status(env_name: str) -> tuple[int, bool]:
url = _api_url_builder(env_name, "/manage/admin/connector/indexing-status/")
try:
indexing_status_dict = requests.get(url, headers=GENERAL_HEADERS).json()
except Exception as e:
@@ -98,20 +89,21 @@ def check_if_query_ready(run_suffix: str) -> bool:
status = index_attempt["last_status"]
if status == IndexingStatus.IN_PROGRESS or status == IndexingStatus.NOT_STARTED:
ongoing_index_attempts = True
elif status == IndexingStatus.SUCCESS:
doc_count += 16
doc_count += index_attempt["docs_indexed"]
doc_count -= 16
if not doc_count:
print("No docs indexed, waiting for indexing to start")
elif ongoing_index_attempts:
print(
f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
)
return doc_count > 0 and not ongoing_index_attempts
# all the +16 and -16 are to account for the fact that the indexing status
# is only updated every 16 documents and only tells us how many are
# chunked, not indexed. This probably needs to be fixed in the future!
if doc_count:
doc_count += 16
return doc_count, ongoing_index_attempts
def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/")
def run_cc_once(env_name: str, connector_id: int, credential_id: int) -> None:
url = _api_url_builder(env_name, "/manage/admin/connector/run-once/")
body = {
"connector_id": connector_id,
"credential_ids": [credential_id],
@@ -126,9 +118,9 @@ def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
print("Failed text:", response.text)
def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> None:
def create_cc_pair(env_name: str, connector_id: int, credential_id: int) -> None:
url = _api_url_builder(
run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}"
env_name, f"/manage/connector/{connector_id}/credential/{credential_id}"
)
body = {"name": "zip_folder_contents", "is_public": True}
@@ -141,8 +133,8 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
print("Failed text:", response.text)
def _get_existing_connector_names(run_suffix: str) -> list[str]:
url = _api_url_builder(run_suffix, "/manage/connector")
def _get_existing_connector_names(env_name: str) -> list[str]:
url = _api_url_builder(env_name, "/manage/connector")
body = {
"credential_json": {},
@@ -156,10 +148,10 @@ def _get_existing_connector_names(run_suffix: str) -> list[str]:
raise RuntimeError(response.__dict__)
def create_connector(run_suffix: str, file_paths: list[str]) -> int:
url = _api_url_builder(run_suffix, "/manage/admin/connector")
def create_connector(env_name: str, file_paths: list[str]) -> int:
url = _api_url_builder(env_name, "/manage/admin/connector")
connector_name = base_connector_name = "search_eval_connector"
existing_connector_names = _get_existing_connector_names(run_suffix)
existing_connector_names = _get_existing_connector_names(env_name)
count = 1
while connector_name in existing_connector_names:
@@ -186,8 +178,8 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
raise RuntimeError(response.__dict__)
def create_credential(run_suffix: str) -> int:
url = _api_url_builder(run_suffix, "/manage/credential")
def create_credential(env_name: str) -> int:
url = _api_url_builder(env_name, "/manage/credential")
body = {
"credential_json": {},
"admin_public": True,
@@ -201,12 +193,12 @@ def create_credential(run_suffix: str) -> int:
@retry(tries=10, delay=2, backoff=2)
def upload_file(run_suffix: str, zip_file_path: str) -> list[str]:
def upload_file(env_name: str, zip_file_path: str) -> list[str]:
files = [
("files", open(zip_file_path, "rb")),
]
api_path = _api_url_builder(run_suffix, "/manage/admin/connector/file/upload")
api_path = _api_url_builder(env_name, "/manage/admin/connector/file/upload")
try:
response = requests.post(api_path, files=files)
response.raise_for_status() # Raises an HTTPError for bad responses

View File

@@ -67,20 +67,20 @@ def switch_to_commit(commit_sha: str) -> None:
print("Repository updated successfully.")
def get_docker_container_env_vars(suffix: str) -> dict:
def get_docker_container_env_vars(env_name: str) -> dict:
"""
Retrieves environment variables from "background" and "api_server" Docker containers.
"""
print(f"Getting environment variables for containers with suffix: {suffix}")
print(f"Getting environment variables for containers with env_name: {env_name}")
combined_env_vars = {}
for container_type in ["background", "api_server"]:
container_name = _run_command(
f"docker ps -a --format '{{{{.Names}}}}' | awk '/{container_type}/ && /{suffix}/'"
f"docker ps -a --format '{{{{.Names}}}}' | awk '/{container_type}/ && /{env_name}/'"
)[0].strip()
if not container_name:
raise RuntimeError(
f"No {container_type} container found with suffix: {suffix}"
f"No {container_type} container found with env_name: {env_name}"
)
env_vars_json = _run_command(
@@ -95,9 +95,9 @@ def get_docker_container_env_vars(suffix: str) -> dict:
return combined_env_vars
def manage_data_directories(suffix: str, base_path: str, use_cloud_gpu: bool) -> None:
def manage_data_directories(env_name: str, base_path: str, use_cloud_gpu: bool) -> None:
# Use the user's home directory as the base path
target_path = os.path.join(os.path.expanduser(base_path), suffix)
target_path = os.path.join(os.path.expanduser(base_path), env_name)
directories = {
"DANSWER_POSTGRES_DATA_DIR": os.path.join(target_path, "postgres/"),
"DANSWER_VESPA_DATA_DIR": os.path.join(target_path, "vespa/"),
@@ -144,12 +144,12 @@ def _is_port_in_use(port: int) -> bool:
def start_docker_compose(
run_suffix: str, launch_web_ui: bool, use_cloud_gpu: bool, only_state: bool = False
env_name: str, launch_web_ui: bool, use_cloud_gpu: bool, only_state: bool = False
) -> None:
print("Starting Docker Compose...")
os.chdir(os.path.dirname(__file__))
os.chdir("../../../../deployment/docker_compose/")
command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack-{run_suffix} up -d"
command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack-{env_name} up -d"
command += " --build"
command += " --force-recreate"
@@ -175,17 +175,17 @@ def start_docker_compose(
print("Containers have been launched")
def cleanup_docker(run_suffix: str) -> None:
def cleanup_docker(env_name: str) -> None:
print(
f"Deleting Docker containers, volumes, and networks for project suffix: {run_suffix}"
f"Deleting Docker containers, volumes, and networks for project env_name: {env_name}"
)
stdout, _ = _run_command("docker ps -a --format '{{json .}}'")
containers = [json.loads(line) for line in stdout.splitlines()]
if not run_suffix:
run_suffix = datetime.now().strftime("-%Y")
project_name = f"danswer-stack{run_suffix}"
if not env_name:
env_name = datetime.now().strftime("-%Y")
project_name = f"danswer-stack{env_name}"
containers_to_delete = [
c for c in containers if c["Names"].startswith(project_name)
]
@@ -221,23 +221,23 @@ def cleanup_docker(run_suffix: str) -> None:
networks = stdout.splitlines()
networks_to_delete = [n for n in networks if run_suffix in n]
networks_to_delete = [n for n in networks if env_name in n]
if not networks_to_delete:
print(f"No networks found containing suffix: {run_suffix}")
print(f"No networks found containing env_name: {env_name}")
else:
network_names = " ".join(networks_to_delete)
_run_command(f"docker network rm {network_names}")
print(
f"Successfully deleted {len(networks_to_delete)} networks containing suffix: {run_suffix}"
f"Successfully deleted {len(networks_to_delete)} networks containing env_name: {env_name}"
)
@retry(tries=5, delay=5, backoff=2)
def get_api_server_host_port(suffix: str) -> str:
def get_api_server_host_port(env_name: str) -> str:
"""
This pulls all containers with the provided suffix
This pulls all containers with the provided env_name
It then grabs the JSON specific container with a name containing "api_server"
It then grabs the port info from the JSON and strips out the relevant data
"""
@@ -248,16 +248,16 @@ def get_api_server_host_port(suffix: str) -> str:
server_jsons = []
for container in containers:
if container_name in container["Names"] and suffix in container["Names"]:
if container_name in container["Names"] and env_name in container["Names"]:
server_jsons.append(container)
if not server_jsons:
raise RuntimeError(
f"No container found containing: {container_name} and {suffix}"
f"No container found containing: {container_name} and {env_name}"
)
elif len(server_jsons) > 1:
raise RuntimeError(
f"Too many containers matching {container_name} found, please indicate a suffix"
f"Too many containers matching {container_name} found, please indicate a env_name"
)
server_json = server_jsons[0]
@@ -278,67 +278,37 @@ def get_api_server_host_port(suffix: str) -> str:
raise RuntimeError(f"Too many ports matching {client_port} found")
if not matching_ports:
raise RuntimeError(
f"No port found containing: {client_port} for container: {container_name} and suffix: {suffix}"
f"No port found containing: {client_port} for container: {container_name} and env_name: {env_name}"
)
return matching_ports[0]
# Added function to check Vespa container health status
def is_vespa_container_healthy(suffix: str) -> bool:
print(f"Checking health status of Vespa container for suffix: {suffix}")
# Find the Vespa container
stdout, _ = _run_command(
f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
)
container_name = stdout.strip()
if not container_name:
print(f"No Vespa container found with suffix: {suffix}")
return False
# Get the health status
stdout, _ = _run_command(
f"docker inspect --format='{{{{.State.Health.Status}}}}' {container_name}"
)
health_status = stdout.strip()
is_healthy = health_status.lower() == "healthy"
print(f"Vespa container '{container_name}' health status: {health_status}")
return is_healthy
# Added function to restart Vespa container
def restart_vespa_container(suffix: str) -> None:
print(f"Restarting Vespa container for suffix: {suffix}")
def restart_vespa_container(env_name: str) -> None:
print(f"Restarting Vespa container for env_name: {env_name}")
# Find the Vespa container
stdout, _ = _run_command(
f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
f"docker ps -a --format '{{{{.Names}}}}' | awk '/index-1/ && /{env_name}/'"
)
container_name = stdout.strip()
if not container_name:
raise RuntimeError(f"No Vespa container found with suffix: {suffix}")
raise RuntimeError(f"No Vespa container found with env_name: {env_name}")
# Restart the container
_run_command(f"docker restart {container_name}")
print(f"Vespa container '{container_name}' has begun restarting")
time_to_wait = 5
while not is_vespa_container_healthy(suffix):
print(f"Waiting {time_to_wait} seconds for vespa container to restart")
time.sleep(5)
time.sleep(30)
print(f"Vespa container '{container_name}' has been restarted")
if __name__ == "__main__":
"""
Running this just cleans up the docker environment for the container indicated by existing_test_suffix
If no existing_test_suffix is indicated, will just clean up all danswer docker containers/volumes/networks
Running this just cleans up the docker environment for the container indicated by environment_name
If no environment_name is indicated, will just clean up all danswer docker containers/volumes/networks
Note: vespa/postgres mounts are not deleted
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -348,4 +318,4 @@ if __name__ == "__main__":
if not isinstance(config, dict):
raise TypeError("config must be a dictionary")
cleanup_docker(config["existing_test_suffix"])
cleanup_docker(config["environment_name"])

View File

@@ -1,8 +1,13 @@
import os
import tempfile
import time
import zipfile
from pathlib import Path
from types import SimpleNamespace
import yaml
from tests.regression.answer_quality.api_utils import check_indexing_status
from tests.regression.answer_quality.api_utils import create_cc_pair
from tests.regression.answer_quality.api_utils import create_connector
from tests.regression.answer_quality.api_utils import create_credential
@@ -10,15 +15,65 @@ from tests.regression.answer_quality.api_utils import run_cc_once
from tests.regression.answer_quality.api_utils import upload_file
def upload_test_files(zip_file_path: str, run_suffix: str) -> None:
def unzip_and_get_file_paths(zip_file_path: str) -> list[str]:
persistent_dir = tempfile.mkdtemp()
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
zip_ref.extractall(persistent_dir)
return [str(path) for path in Path(persistent_dir).rglob("*") if path.is_file()]
def create_temp_zip_from_files(file_paths: list[str]) -> str:
persistent_dir = tempfile.mkdtemp()
zip_file_path = os.path.join(persistent_dir, "temp.zip")
with zipfile.ZipFile(zip_file_path, "w") as zip_file:
for file_path in file_paths:
zip_file.write(file_path, Path(file_path).name)
return zip_file_path
def upload_test_files(zip_file_path: str, env_name: str) -> None:
print("zip:", zip_file_path)
file_paths = upload_file(run_suffix, zip_file_path)
file_paths = upload_file(env_name, zip_file_path)
conn_id = create_connector(run_suffix, file_paths)
cred_id = create_credential(run_suffix)
conn_id = create_connector(env_name, file_paths)
cred_id = create_credential(env_name)
create_cc_pair(run_suffix, conn_id, cred_id)
run_cc_once(run_suffix, conn_id, cred_id)
create_cc_pair(env_name, conn_id, cred_id)
run_cc_once(env_name, conn_id, cred_id)
def manage_file_upload(zip_file_path: str, env_name: str) -> None:
unzipped_file_paths = unzip_and_get_file_paths(zip_file_path)
total_file_count = len(unzipped_file_paths)
while True:
doc_count, ongoing_index_attempts = check_indexing_status(env_name)
if ongoing_index_attempts:
print(
f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
)
elif not doc_count:
print("No docs indexed, waiting for indexing to start")
upload_test_files(zip_file_path, env_name)
elif doc_count < total_file_count:
print(f"No ongooing indexing attempts but only {doc_count} docs indexed")
remaining_files = unzipped_file_paths[doc_count:]
print(f"Grabbed last {len(remaining_files)} docs to try agian")
temp_zip_file_path = create_temp_zip_from_files(remaining_files)
upload_test_files(temp_zip_file_path, env_name)
os.unlink(temp_zip_file_path)
else:
print(f"Successfully uploaded {doc_count} docs!")
break
time.sleep(10)
for file in unzipped_file_paths:
os.unlink(file)
if __name__ == "__main__":
@@ -27,5 +82,5 @@ if __name__ == "__main__":
with open(config_path, "r") as file:
config = SimpleNamespace(**yaml.safe_load(file))
file_location = config.zipped_documents_file
run_suffix = config.existing_test_suffix
upload_test_files(file_location, run_suffix)
env_name = config.environment_name
manage_file_upload(file_location, env_name)

View File

@@ -1,16 +1,12 @@
import os
from datetime import datetime
from types import SimpleNamespace
import yaml
from tests.regression.answer_quality.cli_utils import cleanup_docker
from tests.regression.answer_quality.cli_utils import manage_data_directories
from tests.regression.answer_quality.cli_utils import set_env_variables
from tests.regression.answer_quality.cli_utils import start_docker_compose
from tests.regression.answer_quality.cli_utils import switch_to_commit
from tests.regression.answer_quality.file_uploader import upload_test_files
from tests.regression.answer_quality.run_qa import run_qa_test_and_save_results
def load_config(config_filename: str) -> SimpleNamespace:
@@ -22,12 +18,16 @@ def load_config(config_filename: str) -> SimpleNamespace:
def main() -> None:
config = load_config("search_test_config.yaml")
if config.existing_test_suffix:
run_suffix = config.existing_test_suffix
print("launching danswer with existing data suffix:", run_suffix)
if config.environment_name:
env_name = config.environment_name
print("launching danswer with environment name:", env_name)
else:
run_suffix = datetime.now().strftime("-%Y%m%d-%H%M%S")
print("run_suffix:", run_suffix)
print("No env name defined. Not launching docker.")
print(
"Please define a name in the config yaml to start a new env "
"or use an existing env"
)
return
set_env_variables(
config.model_server_ip,
@@ -35,22 +35,14 @@ def main() -> None:
config.use_cloud_gpu,
config.llm,
)
manage_data_directories(run_suffix, config.output_folder, config.use_cloud_gpu)
manage_data_directories(env_name, config.output_folder, config.use_cloud_gpu)
if config.commit_sha:
switch_to_commit(config.commit_sha)
start_docker_compose(
run_suffix, config.launch_web_ui, config.use_cloud_gpu, config.only_state
env_name, config.launch_web_ui, config.use_cloud_gpu, config.only_state
)
if not config.existing_test_suffix and not config.only_state:
upload_test_files(config.zipped_documents_file, run_suffix)
run_qa_test_and_save_results(run_suffix)
if config.clean_up_docker_containers:
cleanup_docker(run_suffix)
if __name__ == "__main__":
main()

View File

@@ -6,7 +6,6 @@ import time
import yaml
from tests.regression.answer_quality.api_utils import check_if_query_ready
from tests.regression.answer_quality.api_utils import get_answer_from_query
from tests.regression.answer_quality.cli_utils import get_current_commit_sha
from tests.regression.answer_quality.cli_utils import get_docker_container_env_vars
@@ -44,12 +43,12 @@ def _read_questions_jsonl(questions_file_path: str) -> list[dict]:
def _get_test_output_folder(config: dict) -> str:
base_output_folder = os.path.expanduser(config["output_folder"])
if config["run_suffix"]:
if config["env_name"]:
base_output_folder = os.path.join(
base_output_folder, config["run_suffix"], "evaluations_output"
base_output_folder, config["env_name"], "evaluations_output"
)
else:
base_output_folder = os.path.join(base_output_folder, "no_defined_suffix")
base_output_folder = os.path.join(base_output_folder, "no_defined_env_name")
counter = 1
output_folder_path = os.path.join(base_output_folder, "run_1")
@@ -73,12 +72,12 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
metadata = {
"commit_sha": get_current_commit_sha(),
"run_suffix": config["run_suffix"],
"env_name": config["env_name"],
"test_config": config,
"number_of_questions_in_dataset": len(questions),
}
env_vars = get_docker_container_env_vars(config["run_suffix"])
env_vars = get_docker_container_env_vars(config["env_name"])
if env_vars["ENV_SEED_CONFIGURATION"]:
del env_vars["ENV_SEED_CONFIGURATION"]
if env_vars["GPG_KEY"]:
@@ -118,7 +117,8 @@ def _process_question(question_data: dict, config: dict, question_number: int) -
print(f"query: {query}")
context_data_list, answer = get_answer_from_query(
query=query,
run_suffix=config["run_suffix"],
only_retrieve_docs=config["only_retrieve_docs"],
env_name=config["env_name"],
)
if not context_data_list:
@@ -142,27 +142,23 @@ def _process_and_write_query_results(config: dict) -> None:
test_output_folder, questions = _initialize_files(config)
print("saving test results to folder:", test_output_folder)
while not check_if_query_ready(config["run_suffix"]):
time.sleep(5)
if config["limit"] is not None:
questions = questions[: config["limit"]]
with multiprocessing.Pool(processes=multiprocessing.cpu_count() * 2) as pool:
# Use multiprocessing to process questions
with multiprocessing.Pool() as pool:
results = pool.starmap(
_process_question, [(q, config, i + 1) for i, q in enumerate(questions)]
_process_question,
[(question, config, i + 1) for i, question in enumerate(questions)],
)
_populate_results_file(test_output_folder, results)
invalid_answer_count = 0
for result in results:
if not result.get("answer"):
if len(result["context_data_list"]) == 0:
invalid_answer_count += 1
if not result.get("context_data_list"):
raise RuntimeError("Search failed, this is a critical failure!")
_update_metadata_file(test_output_folder, invalid_answer_count)
if invalid_answer_count:
@@ -177,7 +173,7 @@ def _process_and_write_query_results(config: dict) -> None:
print("saved test results to folder:", test_output_folder)
def run_qa_test_and_save_results(run_suffix: str = "") -> None:
def run_qa_test_and_save_results(env_name: str = "") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, "search_test_config.yaml")
with open(config_path, "r") as file:
@@ -186,16 +182,16 @@ def run_qa_test_and_save_results(run_suffix: str = "") -> None:
if not isinstance(config, dict):
raise TypeError("config must be a dictionary")
if not run_suffix:
run_suffix = config["existing_test_suffix"]
if not env_name:
env_name = config["environment_name"]
config["run_suffix"] = run_suffix
config["env_name"] = env_name
_process_and_write_query_results(config)
if __name__ == "__main__":
"""
To run a different set of questions, update the questions_file in search_test_config.yaml
If there is more than one instance of Danswer running, specify the suffix in search_test_config.yaml
If there is more than one instance of Danswer running, specify the env_name in search_test_config.yaml
"""
run_qa_test_and_save_results()

View File

@@ -13,14 +13,11 @@ questions_file: "~/sample_questions.yaml"
# Git commit SHA to use (null means use current code as is)
commit_sha: null
# Whether to remove Docker containers after the test
clean_up_docker_containers: true
# Whether to launch a web UI for the test
launch_web_ui: false
# Whether to only run Vespa and Postgres
only_state: false
# Only retrieve documents, not LLM response
only_retrieve_docs: false
# Whether to use a cloud GPU for processing
use_cloud_gpu: false
@@ -31,9 +28,8 @@ model_server_ip: "PUT_PUBLIC_CLOUD_IP_HERE"
# Port of the model server (placeholder)
model_server_port: "PUT_PUBLIC_CLOUD_PORT_HERE"
# Suffix for existing test results (E.g. -1234-5678)
# empty string means no suffix
existing_test_suffix: ""
# Name for existing testing env (empty string uses default ports)
environment_name: ""
# Limit on number of tests to run (null means no limit)
limit: null

View File

@@ -31,8 +31,8 @@ def test_chunk_document() -> None:
chunks = chunk_document(document)
assert len(chunks) == 5
assert all(semantic_identifier in chunk.content for chunk in chunks)
assert short_section_1 in chunks[0].content
assert short_section_3 in chunks[-1].content
assert short_section_4 in chunks[-1].content
assert "tag1" in chunks[0].content
assert "tag1" in chunks[0].metadata_suffix_keyword
assert "tag2" in chunks[0].metadata_suffix_semantic

View File

@@ -25,8 +25,8 @@ services:
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
- REQUIRE_EMAIL_VERIFICATION=${REQUIRE_EMAIL_VERIFICATION:-}
- SMTP_SERVER=${SMTP_SERVER:-} # For sending verification emails, if unspecified then defaults to 'smtp.gmail.com'
- SMTP_PORT=${SMTP_PORT:-587} # For sending verification emails, if unspecified then defaults to '587'
- SMTP_SERVER=${SMTP_SERVER:-} # For sending verification emails, if unspecified then defaults to 'smtp.gmail.com'
- SMTP_PORT=${SMTP_PORT:-587} # For sending verification emails, if unspecified then defaults to '587'
- SMTP_USER=${SMTP_USER:-}
- SMTP_PASS=${SMTP_PASS:-}
- EMAIL_FROM=${EMAIL_FROM:-}
@@ -59,8 +59,8 @@ services:
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
- AWS_REGION_NAME=${AWS_REGION_NAME:-}
# Query Options
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
- LANGUAGE_HINT=${LANGUAGE_HINT:-}
@@ -69,7 +69,7 @@ services:
# Other services
- POSTGRES_HOST=relational_db
- VESPA_HOST=index
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose
# Don't change the NLP model configs unless you know what you're doing
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
@@ -104,7 +104,6 @@ services:
max-size: "50m"
max-file: "6"
background:
image: danswer/danswer-backend:latest
build:
@@ -139,8 +138,8 @@ services:
- LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
- BING_API_KEY=${BING_API_KEY:-}
# Query Options
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
- LANGUAGE_HINT=${LANGUAGE_HINT:-}
@@ -152,12 +151,12 @@ services:
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
- POSTGRES_DB=${POSTGRES_DB:-}
- VESPA_HOST=index
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors
# Don't change the NLP model configs unless you know what you're doing
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} # Needed by DanswerBot
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} # Needed by DanswerBot
- ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
@@ -183,7 +182,7 @@ services:
- DANSWER_BOT_FEEDBACK_VISIBILITY=${DANSWER_BOT_FEEDBACK_VISIBILITY:-}
- DANSWER_BOT_DISPLAY_ERROR_MSGS=${DANSWER_BOT_DISPLAY_ERROR_MSGS:-}
- DANSWER_BOT_RESPOND_EVERY_CHANNEL=${DANSWER_BOT_RESPOND_EVERY_CHANNEL:-}
- DANSWER_BOT_DISABLE_COT=${DANSWER_BOT_DISABLE_COT:-} # Currently unused
- DANSWER_BOT_DISABLE_COT=${DANSWER_BOT_DISABLE_COT:-} # Currently unused
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
- DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-}
- DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-}
@@ -207,7 +206,6 @@ services:
max-size: "50m"
max-file: "6"
web_server:
image: danswer/danswer-web-server:latest
build:
@@ -219,6 +217,7 @@ services:
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
# Enterprise Edition only
- NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
@@ -236,7 +235,6 @@ services:
# Enterprise Edition only
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
inference_model_server:
image: danswer/danswer-model-server:latest
build:
@@ -263,7 +261,6 @@ services:
max-size: "50m"
max-file: "6"
indexing_model_server:
image: danswer/danswer-model-server:latest
build:
@@ -291,7 +288,6 @@ services:
max-size: "50m"
max-file: "6"
relational_db:
image: postgres:15.2-alpine
restart: always
@@ -303,7 +299,6 @@ services:
volumes:
- db_volume:/var/lib/postgresql/data
# This container name cannot have an underscore in it due to Vespa expectations of the URL
index:
image: vespaengine/vespa:8.277.17
@@ -319,7 +314,6 @@ services:
max-size: "50m"
max-file: "6"
nginx:
image: nginx:1.23.4-alpine
restart: always
@@ -332,7 +326,7 @@ services:
- DOMAIN=localhost
ports:
- "80:80"
- "3000:80" # allow for localhost:3000 usage, since that is the norm
- "3000:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
logging:
@@ -349,10 +343,9 @@ services:
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev"
volumes:
db_volume:
vespa_volume:
# Created by the container itself
vespa_volume: # Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:

View File

@@ -211,6 +211,7 @@ services:
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
- NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
depends_on:
- api_server

View File

@@ -70,6 +70,7 @@ services:
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
- NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
depends_on:
- api_server

View File

@@ -75,6 +75,7 @@ services:
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
# Enterprise Edition only
- NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}

View File

@@ -46,6 +46,9 @@ ENV NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PRED
ARG NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS
ENV NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS}
ARG NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN
ENV NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN}
ARG NEXT_PUBLIC_THEME
ENV NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME}
@@ -106,6 +109,9 @@ ENV NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME}
ARG NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED
ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED}
ARG NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN
ENV NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN}
ARG NEXT_PUBLIC_DISABLE_LOGOUT
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}

View File

@@ -436,7 +436,7 @@ export function AssistantEditor({
<div className="mb-6">
<div className="flex gap-x-2 items-center">
<div className="block font-medium text-base">
LLM Provider{" "}
LLM Override{" "}
</div>
<TooltipProvider delayDuration={50}>
<Tooltip>
@@ -452,6 +452,10 @@ export function AssistantEditor({
</Tooltip>
</TooltipProvider>
</div>
<p className="my-1 text-text-600">
Your assistant will use your system default (currently{" "}
{defaultModelName}) unless otherwise specified below.
</p>
<div className="mb-2 flex items-starts">
<div className="w-96">
<SelectorFormField

View File

@@ -136,10 +136,6 @@ export default function Page({ params }: { params: { ccPairId: string } }) {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<Main ccPairId={ccPairId} />
</div>
);

View File

@@ -248,10 +248,6 @@ const MainSection = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<AxeroIcon size={32} />} title="Axero" />
<MainSection />

View File

@@ -249,10 +249,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<BookstackIcon size={32} />} title="Bookstack" />
<Main />

View File

@@ -331,10 +331,6 @@ const MainSection = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<ClickupIcon size={32} />} title="Clickup" />
<MainSection />

View File

@@ -330,10 +330,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<ConfluenceIcon size={32} />} title="Confluence" />
<Main />

View File

@@ -273,10 +273,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<DiscourseIcon size={32} />} title="Discourse" />
<Main />

View File

@@ -262,10 +262,6 @@ const MainSection = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle
icon={<Document360Icon size={32} />}
title="Document360"

View File

@@ -212,9 +212,7 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
{" "}
<AdminPageTitle icon={<DropboxIcon size={32} />} title="Dropbox" />
<Main />
</div>

View File

@@ -287,10 +287,6 @@ const Main = () => {
export default function File() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<FileIcon size={32} />} title="File" />
<Main />

View File

@@ -265,10 +265,6 @@ const Main = () => {
export default function Page() {
return (
<div className="container mx-auto">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle
icon={<GithubIcon size={32} />}
title="Github PRs + Issues"

View File

@@ -250,10 +250,6 @@ const Main = () => {
export default function Page() {
return (
<div className="container mx-auto">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle
icon={<GitlabIcon size={32} />}
title="Gitlab MRs + Issues"

View File

@@ -264,10 +264,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<GmailIcon size={32} />} title="Gmail" />
<Main />

View File

@@ -257,10 +257,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<GongIcon size={32} />} title="Gong" />
<Main />

View File

@@ -412,10 +412,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle
icon={<GoogleDriveIcon size={32} />}
title="Google Drive"

View File

@@ -50,10 +50,6 @@ export default function GoogleSites() {
{popup}
{filesAreUploading && <Spinner />}
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle
icon={<GoogleSitesIcon size={32} />}
title="Google Sites"

View File

@@ -244,9 +244,7 @@ const GCSMain = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
{" "}
<AdminPageTitle
icon={<GoogleStorageIcon size={32} />}
title="Google Cloud Storage"

View File

@@ -232,10 +232,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<GuruIcon size={32} />} title="Guru" />
<Main />

View File

@@ -220,10 +220,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<HubSpotIcon size={32} />} title="HubSpot" />
<Main />

View File

@@ -362,10 +362,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<JiraIcon size={32} />} title="Jira" />
<Main />

View File

@@ -224,10 +224,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<LinearIcon size={32} />} title="Linear" />
<Main />

View File

@@ -250,9 +250,7 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
{" "}
<div className="border-solid border-gray-600 border-b mb-4 pb-2 flex">
<LoopioIcon size={32} />
<h1 className="text-3xl font-bold pl-2">Loopio</h1>

View File

@@ -207,10 +207,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<MediaWikiIcon size={32} />} title="MediaWiki" />
<Main />

View File

@@ -260,10 +260,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<NotionIcon size={32} />} title="Notion" />
<Main />

View File

@@ -259,9 +259,7 @@ const OCIMain = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
{" "}
<AdminPageTitle
icon={<OCIStorageIcon size={32} />}
title="Oracle Cloud Infrastructure"

View File

@@ -239,10 +239,6 @@ const Main = () => {
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle
icon={<ProductboardIcon size={32} />}
title="Productboard"

Some files were not shown because too many files have changed in this diff.