Compare commits

..

3 Commits

Author SHA1 Message Date
pablodanswer
c68602f456 specifically apply flex none to in progress! 2024-11-10 18:43:22 -08:00
rkuo-danswer
9d57f34c34 re-enable helm (#3053)
* re-enable helm

* allow manual triggering

* change vespa host

* change vespa chart location

* update Chart.lock

* update ct.yaml with new vespa chart repo

* bump vespa to 0.2.5

* update Chart.lock

* update to vespa 0.2.6

* bump vespa to 0.2.7

* bump to 0.2.8

* bump version

* try appending the ordinal

* try new configmap

* bump vespa

* bump vespa

* add debug to see if we can figure out what ct install thinks is failing

* add debug flag to helm

* try disabling nginx because of KinD

* use helm-extra-set-args

* try command line

* try pointing test connection to the correct service name

* bump vespa to 0.2.12

* update chart.lock

* bump vespa to 0.2.13

* bump vespa to 0.2.14

* bump vespa

* bump vespa

* re-enable chart testing only on changes

* name the check more specifically than "lint-test"

* add some debugging

* try setting remote

* might have to specify chart dirs directly

* add comments

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-10 01:28:39 +00:00
pablodanswer
cc2f584321 Silence auth logs (#3098)
* silence auth logs

* remove unnecessary line

* k
2024-11-09 21:41:11 +00:00
46 changed files with 11157 additions and 6106 deletions

View File

@@ -288,15 +288,6 @@ def upgrade() -> None:
def downgrade() -> None:
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
# below
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM chat_message__search_doc")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM chat_message")
op.execute("DELETE FROM chat_session")
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)

View File

@@ -23,56 +23,6 @@ def upgrade() -> None:
def downgrade() -> None:
# Delete chat messages and feedback first since they reference chat sessions
# Get chat messages from sessions with null persona_id
chat_messages_query = """
SELECT id
FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
# Delete dependent records first
op.execute(
f"""
DELETE FROM document_retrieval_feedback
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
op.execute(
f"""
DELETE FROM chat_message__search_doc
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
# Delete chat messages
op.execute(
"""
DELETE FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
)
# Now we can safely delete the chat sessions
op.execute(
"""
DELETE FROM chat_session
WHERE persona_id IS NULL
"""
)
op.alter_column(
"chat_session",
"persona_id",

View File

@@ -12,7 +12,6 @@ from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
@@ -73,15 +72,6 @@ class DynamicTenantScheduler(PersistentScheduler):
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
for tenant_id in tenant_ids:
if (
IGNORED_SYNCING_TENANT_LIST
and tenant_id in IGNORED_SYNCING_TENANT_LIST
):
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
if tenant_id not in existing_tenants:
logger.info(f"Processing new tenant: {tenant_id}")

View File

@@ -6,7 +6,6 @@ from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
@@ -82,11 +81,6 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any

View File

@@ -0,0 +1,96 @@
from datetime import timedelta
from typing import Any
from celery.beat import PersistentScheduler # type: ignore
from celery.utils.log import get_task_logger
from danswer.db.engine import get_all_tenant_ids
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = get_task_logger(__name__)
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=1)
self._last_reload = self.app.now() - self._reload_interval
def setup_schedule(self) -> None:
super().setup_schedule()
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reloading schedule to check for new tenants...")
self._update_tenant_tasks()
self._last_reload = now
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Checking for tenant task updates...")
try:
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = getattr(self, "_store", {"entries": {}}).get(
"entries", {}
)
existing_tenants = set()
for task_name in current_schedule.keys():
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
for tenant_id in tenant_ids:
if tenant_id not in existing_tenants:
logger.info(f"Found new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Updating schedule",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
if not hasattr(self, "_store"):
self._store: dict[str, dict] = {"entries": {}}
self.update_from_dict(new_beat_schedule)
logger.info(f"New schedule: {new_beat_schedule}")
logger.info("Tenant tasks updated successfully")
else:
logger.debug("No schedule updates needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
current_tasks = set(current_schedule.keys())
new_tasks = set(new_schedule.keys())
return current_tasks != new_tasks

View File

@@ -8,7 +8,7 @@ tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=20),
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
@@ -20,13 +20,13 @@ tasks_to_schedule = [
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=15),
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=15),
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{

View File

@@ -29,26 +29,18 @@ JobStatusType = (
def _initializer(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> Any:
"""Initialize the child process with a fresh SQLAlchemy Engine.
"""Ensure the parent proc's database connections are not touched
in the new connection pool
Based on SQLAlchemy's recommendations to handle multiprocessing:
Based on the recommended approach in the SQLAlchemy docs found:
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
"""
if kwargs is None:
kwargs = {}
logger.info("Initializing spawned worker child process.")
# Reset the engine in the child process
SqlEngine.reset_engine()
# Optionally set a custom app name for database logging purposes
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# Initialize a new engine with desired parameters
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
# Proceed with executing the target function
return func(*args, **kwargs)

View File

@@ -19,10 +19,16 @@ from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -35,6 +41,7 @@ from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
@@ -54,13 +61,14 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
@@ -69,14 +77,14 @@ from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.force import ForceUseTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_constructor import construct_tools
from danswer.tools.tool_constructor import CustomToolConfig
from danswer.tools.tool_constructor import ImageGenerationToolConfig
from danswer.tools.tool_constructor import InternetSearchToolConfig
from danswer.tools.tool_constructor import SearchToolConfig
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
@@ -87,6 +95,9 @@ from danswer.tools.tool_implementations.images.image_generation_tool import (
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
@@ -111,6 +122,9 @@ from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@@ -281,6 +295,7 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -292,9 +307,6 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
@@ -416,20 +428,12 @@ def stream_chat_message_objects(
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if existing_assistant_message_id is None:
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
else:
if final_msg.id != existing_assistant_message_id:
raise RuntimeError(
"The last message was not the existing assistant message. "
f"Final message id: {final_msg.id}, "
f"existing assistant message id: {existing_assistant_message_id}"
)
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
@@ -500,19 +504,13 @@ def stream_chat_message_objects(
),
max_window_percentage=max_document_percentage,
)
# we don't need to reserve a message id if we're using an existing assistant message
reserved_message_id = (
final_msg.id
if existing_assistant_message_id is not None
else reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
@@ -527,13 +525,7 @@ def stream_chat_message_objects(
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
# if we're using an existing assistant message, then this will just be an
# update operation, in which case the parent should be the parent of
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
),
parent_message=final_msg,
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
@@ -545,7 +537,6 @@ def stream_chat_message_objects(
# reference_docs=,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
)
if not final_msg.prompt:
@@ -569,39 +560,142 @@ def stream_chat_message_objects(
structured_response_format=new_msg_req.structured_response_format,
)
tool_dict = construct_tools(
persona=persona,
prompt_config=prompt_config,
db_session=db_session,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
latest_query_files=latest_query_files,
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
),
image_generation_tool_config=ImageGenerationToolConfig(
additional_headers=litellm_additional_headers,
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
additional_headers=custom_tool_additional_headers,
),
)
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
answer_style_config=answer_style_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
elif (
llm.config.model_provider == "azure"
and AZURE_DALLE_API_KEY is not None
):
img_generation_llm_config = LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
model=img_generation_llm_config.model_name,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(
api_key=bing_api_key,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
)
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
@@ -777,6 +871,7 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@@ -784,11 +879,9 @@ def stream_chat_message_objects(
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=(
message_specific_citations.citation_map
if message_specific_citations
else None
),
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
error=None,
tool_call=(
ToolCall(
@@ -822,6 +915,7 @@ def stream_chat_message_objects(
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -831,6 +925,7 @@ def stream_chat_message(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,

View File

@@ -55,11 +55,11 @@ def validate_channel_names(
# Scaling configurations for multi-tenant Slack bot handling
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
TENANT_HEARTBEAT_INTERVAL = (
15 # How often pods send heartbeats to indicate they are still processing a tenant
60 # How often pods send heartbeats to indicate they are still processing a tenant
)
TENANT_HEARTBEAT_EXPIRATION = (
30 # How long before a tenant's heartbeat expires, allowing other pods to take over
TENANT_HEARTBEAT_EXPIRATION = 180 # How long before a tenant's heartbeat expires, allowing other pods to take over
TENANT_ACQUISITION_INTERVAL = (
60 # How often pods attempt to acquire unprocessed tenants
)
TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and checks for new tokens
MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))

View File

@@ -75,7 +75,6 @@ from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import DISALLOWED_SLACK_BOT_TENANT_LIST
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -165,15 +164,9 @@ class SlackbotHandler:
def acquire_tenants(self) -> None:
tenant_ids = get_all_tenant_ids()
logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
for tenant_id in tenant_ids:
if (
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
):
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
continue
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue
@@ -197,9 +190,6 @@ class SlackbotHandler:
continue
logger.debug(f"Acquired lock for tenant {tenant_id}")
self.tenant_ids.add(tenant_id)
for tenant_id in self.tenant_ids:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
@@ -246,14 +236,14 @@ class SlackbotHandler:
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if self.socket_clients.get(tenant_id):
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if self.socket_clients.get(tenant_id):
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
@@ -287,14 +277,14 @@ class SlackbotHandler:
logger.info(f"Connecting socket client for tenant {tenant_id}")
socket_client.connect()
self.socket_clients[tenant_id] = socket_client
self.tenant_ids.add(tenant_id)
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for tenant_id, client in self.socket_clients.items():
if client:
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
if not self.running:
@@ -308,16 +298,6 @@ class SlackbotHandler:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()
# Release locks for all tenants
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
try:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(DanswerRedisLocks.SLACK_BOT_LOCK)
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
# Wait for background threads to finish (with timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)

View File

@@ -24,13 +24,6 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
return tool
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
if not tool:
raise ValueError("Tool by specified name does not exist")
return tool
def create_tool(
name: str,
description: str | None,
@@ -44,7 +37,7 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.model_dump() for header in custom_headers]
custom_headers=[header.dict() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,

View File

@@ -74,9 +74,6 @@ from danswer.server.manage.search_settings import router as search_settings_rout
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.users import router as user_router
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
from danswer.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
)
from danswer.server.query_and_chat.chat_backend import router as chat_router
from danswer.server.query_and_chat.query_backend import (
admin_router as admin_query_router,
@@ -273,9 +270,6 @@ def get_application() -> FastAPI:
application, token_rate_limit_settings_router
)
include_router_with_global_prefix_prepended(application, indexing_router)
include_router_with_global_prefix_prepended(
application, get_full_openai_assistants_api_router()
)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@@ -63,7 +63,6 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (

File diff suppressed because it is too large Load Diff

View File

@@ -1,44 +0,0 @@
[
{
"url": "https://docs.danswer.dev/more/use_cases/overview",
"title": "Use Cases Overview",
"content": "How to leverage Danswer in your organization\n\nDanswer Overview\nDanswer is the AI Assistant connected to your organization's docs, apps, and people. Danswer makes Generative AI more versatile for work by enabling new types of questions like \"What is the most common feature request we've heard from customers this month\". Whereas other AI systems have no context of your team and are generally unhelpful with work related questions, Danswer makes it possible to ask these questions in natural language and get back answers in seconds.\n\nDanswer can connect to +30 different tools and the use cases are not limited to the ones in the following pages. The highlighted use cases are for inspiration and come from feedback gathered from our users and customers.\n\n\nCommon Getting Started Questions:\n\nWhy are these docs connected in my Danswer deployment?\nAnswer: This is just an example of how connectors work in Danswer. You can connect up your own team's knowledge and you will be able to ask questions unique to your organization. Danswer will keep all of the knowledge up to date and in sync with your connected applications.\n\nIs my data being sent anywhere when I connect it up to Danswer?\nAnswer: No! Danswer is built with data security as our highest priority. We open sourced it so our users can know exactly what is going on with their data. By default all of the document processing happens within Danswer. The only time it is sent outward is for the GenAI call to generate answers.\n\nWhere is the feature for auto sync-ing document level access permissions from all connected sources?\nAnswer: This falls under the Enterprise Edition set of Danswer features built on top of the MIT/community edition. If you are on Danswer Cloud, you have access to them by default. If you're running it yourself, reach out to the Danswer team to receive access.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/enterprise_search",
"title": "Enterprise Search",
"content": "Value of Enterprise Search with Danswer\n\nWhat is Enterprise Search and why is it Important?\nAn Enterprise Search system gives team members a single place to access all of the disparate knowledge of an organization. Critical information is saved across a host of channels like call transcripts with prospects, engineering design docs, IT runbooks, customer support email exchanges, project management tickets, and more. As fast moving teams scale up, information gets spread out and more disorganized.\n\nSince it quickly becomes infeasible to check across every source, decisions get made on incomplete information, employee satisfaction decreases, and the most valuable members of your team are tied up with constant distractions as junior teammates are unable to unblock themselves. Danswer solves this problem by letting anyone on the team access all of the knowledge across your organization in a permissioned and secure way. Users can ask questions in natural language and get back answers and documents across all of the connected sources instantly.\n\nWhat's the real cost?\nA typical knowledge worker spends over 2 hours a week on search, but more than that, the cost of incomplete or incorrect information can be extremely high. Customer support/success that isn't able to find the reference to similar cases could cause hours or even days of delay leading to lower customer satisfaction or in the worst case - churn. An account exec not realizing that a prospect had previously mentioned a specific need could lead to lost deals. An engineer not realizing a similar feature had previously been built could result in weeks of wasted development time and tech debt with duplicate implementation. With a lack of knowledge, your whole organization is navigating in the dark - inefficient and mistake prone.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/enterprise_search",
"title": "Enterprise Search",
"content": "More than Search\nWhen analyzing the entire corpus of knowledge within your company is as easy as asking a question in a search bar, your entire team can stay informed and up to date. Danswer also makes it trivial to identify where knowledge is well documented and where it is lacking. Team members who are centers of knowledge can begin to effectively document their expertise since it is no longer being thrown into a black hole. All of this allows the organization to achieve higher efficiency and drive business outcomes.\n\nWith Generative AI, the entire user experience has evolved as well. For example, instead of just finding similar cases for your customer support team to reference, Danswer breaks down the issue and explains it so that even the most junior members can understand it. This in turn lets them give the most holistic and technically accurate response possible to your customers. On the other end, even the super stars of your sales team will not be able to review 10 hours of transcripts before hopping on that critical call, but Danswer can easily parse through it in mere seconds and give crucial context to help your team close.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/ai_platform",
"title": "AI Platform",
"content": "Build AI Agents powered by the knowledge and workflows specific to your organization.\n\nBeyond Answers\nAgents enabled by generative AI and reasoning capable models are helping teams to automate their work. Danswer is helping teams make it happen. Danswer provides out of the box user chat sessions, attaching custom tools, handling LLM reasoning, code execution, data analysis, referencing internal knowledge, and much more.\n\nDanswer as a platform is not a no-code agent builder. We are made by developers for developers and this gives your team the full flexibility and power to create agents not constrained by blocks and simple logic paths.\n\nFlexibility and Extensibility\nDanswer is open source and completely whitebox. This not only gives transparency to what happens within the system but also means that your team can directly modify the source code to suit your unique needs.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/customer_support",
"title": "Customer Support",
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nDanswer takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/sales",
"title": "Sales",
"content": "Keep your team up to date on every conversation and update so they can close.\n\nRecall Every Detail\nBeing able to instantly revisit every detail of any call without reading transcripts is helping Sales teams provide more tailored pitches, build stronger relationships, and close more deals. Instead of searching and reading through hours of transcripts in preparation for a call, your team can now ask Danswer \"What specific features was ACME interested in seeing for the demo\". Since your team doesn't have time to read every transcript prior to a call, Danswer provides a more thorough summary because it can instantly parse hundreds of pages and distill out the relevant information. Even for fast lookups it becomes much more convenient - for example to brush up on connection building topics by asking \"What rapport building topic did we chat about in the last call with ACME\".\n\nKnow Every Product Update\nIt is impossible for Sales teams to keep up with every product update. Because of this, when a prospect has a question that the Sales team does not know, they have no choice but to rely on the Product and Engineering orgs to get an authoritative answer. Not only is this distracting to the other teams, it also slows down the time to respond to the prospect (and as we know, time is the biggest killer of deals). With Danswer, it is even possible to get answers live on call because of how fast accessing information becomes. A question like \"Have we shipped the Microsoft AD integration yet?\" can now be answered in seconds meaning that prospects can get answers while on the call instead of asynchronously and sales cycles are reduced as a result.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/operations",
"title": "Operations",
"content": "Double the productivity of your Ops teams like IT, HR, etc.\n\nAutomatically Resolve Tickets\nModern teams are leveraging AI to auto-resolve up to 50% of tickets. Whether it is an employee asking about benefits details or how to set up the VPN for remote work, Danswer can help your team help themselves. This frees up your team to do the real impactful work of landing star candidates or improving your internal processes.\n\nAI Aided Onboarding\nOne of the periods where your team needs the most help is when they're just ramping up. Instead of feeling lost in dozens of new tools, Danswer gives them a single place where they can ask about anything in natural language. Whether it's how to set up their work environment or what their onboarding goals are, Danswer can walk them through every step with the help of Generative AI. This lets your team feel more empowered and gives time back to the more seasoned members of your team to focus on moving the needle.",
"chunk_ind": 0
}
]

View File

@@ -3,7 +3,6 @@ import json
import os
from typing import cast
from cohere import Client
from sqlalchemy.orm import Session
from danswer.access.models import default_public_access
@@ -33,7 +32,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.documents.models import ConnectorBase
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
logger = setup_logger()
@@ -92,9 +91,7 @@ def _create_indexable_chunks(
return list(ids_to_documents.values()), chunks
def seed_initial_documents(
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
) -> None:
def seed_initial_documents(db_session: Session, tenant_id: str | None) -> None:
"""
Seed initial documents so users don't have an empty index to start
@@ -135,9 +132,7 @@ def seed_initial_documents(
return
search_settings = get_current_search_settings(db_session)
if search_settings.model_name != DEFAULT_DOCUMENT_ENCODER_MODEL and not (
search_settings.model_name == "embed-english-v3.0" and cohere_enabled
):
if search_settings.model_name != DEFAULT_DOCUMENT_ENCODER_MODEL:
logger.info("Embedding model has been updated, skipping")
return
@@ -178,31 +173,10 @@ def seed_initial_documents(
)
cc_pair_id = cast(int, result.data)
if cohere_enabled:
initial_docs_path = os.path.join(
os.getcwd(), "danswer", "seeding", "initial_docs_cohere.json"
)
cohere_client = Client(COHERE_DEFAULT_API_KEY)
processed_docs = json.load(open(initial_docs_path))
for doc in processed_docs:
title_embedding = cohere_client.embed(
texts=[doc["title"]], model="embed-english-v3.0"
).embeddings[0]
content_embedding = cohere_client.embed(
texts=[doc["content"]], model="embed-english-v3.0"
).embeddings[0]
doc["title_embedding"] = title_embedding
doc["content_embedding"] = content_embedding
else:
initial_docs_path = os.path.join(
os.getcwd(),
"danswer",
"seeding",
"initial_docs.json",
)
processed_docs = json.load(open(initial_docs_path))
initial_docs_path = os.path.join(
os.getcwd(), "danswer", "seeding", "initial_docs.json"
)
processed_docs = json.load(open(initial_docs_path))
docs, chunks = _create_indexable_chunks(processed_docs, tenant_id)

View File

@@ -1,273 +0,0 @@
from typing import Any
from typing import Optional
from uuid import uuid4
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.db.tools import get_tool_by_name
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/assistants")
# Base models
class AssistantObject(BaseModel):
id: int
object: str = "assistant"
created_at: int
name: Optional[str] = None
description: Optional[str] = None
model: str
instructions: Optional[str] = None
tools: list[dict[str, Any]]
file_ids: list[str]
metadata: Optional[dict[str, Any]] = None
class CreateAssistantRequest(BaseModel):
model: str
name: Optional[str] = None
description: Optional[str] = None
instructions: Optional[str] = None
tools: Optional[list[dict[str, Any]]] = None
file_ids: Optional[list[str]] = None
metadata: Optional[dict[str, Any]] = None
class ModifyAssistantRequest(BaseModel):
model: Optional[str] = None
name: Optional[str] = None
description: Optional[str] = None
instructions: Optional[str] = None
tools: Optional[list[dict[str, Any]]] = None
file_ids: Optional[list[str]] = None
metadata: Optional[dict[str, Any]] = None
class DeleteAssistantResponse(BaseModel):
id: int
object: str = "assistant.deleted"
deleted: bool
class ListAssistantsResponse(BaseModel):
object: str = "list"
data: list[AssistantObject]
first_id: Optional[int] = None
last_id: Optional[int] = None
has_more: bool
def persona_to_assistant(persona: Persona) -> AssistantObject:
return AssistantObject(
id=persona.id,
created_at=0,
name=persona.name,
description=persona.description,
model=persona.llm_model_version_override or "gpt-3.5-turbo",
instructions=persona.prompts[0].system_prompt if persona.prompts else None,
tools=[
{
"type": tool.display_name,
"function": {
"name": tool.name,
"description": tool.description,
"schema": tool.openapi_schema,
},
}
for tool in persona.tools
],
file_ids=[], # Assuming no file support for now
metadata={}, # Assuming no metadata for now
)
# API endpoints
@router.post("")
def create_assistant(
request: CreateAssistantRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
prompt = None
if request.instructions:
prompt = upsert_prompt(
user=user,
name=f"Prompt for {request.name or 'New Assistant'}",
description="Auto-generated prompt",
system_prompt=request.instructions,
task_prompt="",
include_citations=True,
datetime_aware=True,
personas=[],
db_session=db_session,
)
tool_ids = []
for tool in request.tools or []:
tool_type = tool.get("type")
if not tool_type:
continue
try:
tool_db = get_tool_by_name(tool_type, db_session)
tool_ids.append(tool_db.id)
except ValueError:
# Skip tools that don't exist in the database
logger.error(f"Tool {tool_type} not found in database")
raise HTTPException(
status_code=404, detail=f"Tool {tool_type} not found in database"
)
persona = upsert_persona(
user=user,
name=request.name or f"Assistant-{uuid4()}",
description=request.description or "",
num_chunks=25,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
llm_model_provider_override=None,
llm_model_version_override=request.model,
starter_messages=None,
is_public=False,
db_session=db_session,
prompt_ids=[prompt.id] if prompt else [0],
document_set_ids=[],
tool_ids=tool_ids,
icon_color=None,
icon_shape=None,
is_visible=True,
)
if prompt:
prompt.personas = [persona]
db_session.commit()
return persona_to_assistant(persona)
""
@router.get("/{assistant_id}")
def retrieve_assistant(
assistant_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
try:
persona = get_persona_by_id(
persona_id=assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
except ValueError:
persona = None
if not persona:
raise HTTPException(status_code=404, detail="Assistant not found")
return persona_to_assistant(persona)
@router.post("/{assistant_id}")
def modify_assistant(
assistant_id: int,
request: ModifyAssistantRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
persona = get_persona_by_id(
persona_id=assistant_id,
user=user,
db_session=db_session,
is_for_edit=True,
)
if not persona:
raise HTTPException(status_code=404, detail="Assistant not found")
update_data = request.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(persona, key, value)
if "instructions" in update_data and persona.prompts:
persona.prompts[0].system_prompt = update_data["instructions"]
db_session.commit()
return persona_to_assistant(persona)
@router.delete("/{assistant_id}")
def delete_assistant(
assistant_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> DeleteAssistantResponse:
try:
mark_persona_as_deleted(
persona_id=int(assistant_id),
user=user,
db_session=db_session,
)
return DeleteAssistantResponse(id=assistant_id, deleted=True)
except ValueError:
raise HTTPException(status_code=404, detail="Assistant not found")
@router.get("")
def list_assistants(
limit: int = Query(20, le=100),
order: str = Query("desc", regex="^(asc|desc)$"),
after: Optional[int] = None,
before: Optional[int] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ListAssistantsResponse:
personas = list(
get_personas(
user=user,
db_session=db_session,
get_editable=False,
joinedload_all=True,
)
)
# Apply filtering based on after and before
if after:
personas = [p for p in personas if p.id > int(after)]
if before:
personas = [p for p in personas if p.id < int(before)]
# Apply ordering
personas.sort(key=lambda p: p.id, reverse=(order == "desc"))
# Apply limit
personas = personas[:limit]
assistants = [persona_to_assistant(p) for p in personas]
return ListAssistantsResponse(
data=assistants,
first_id=assistants[0].id if assistants else None,
last_id=assistants[-1].id if assistants else None,
has_more=len(personas) == limit,
)

View File

@@ -1,19 +0,0 @@
from fastapi import APIRouter
from danswer.server.openai_assistants_api.asssistants_api import (
router as assistants_router,
)
from danswer.server.openai_assistants_api.messages_api import router as messages_router
from danswer.server.openai_assistants_api.runs_api import router as runs_router
from danswer.server.openai_assistants_api.threads_api import router as threads_router
def get_full_openai_assistants_api_router() -> APIRouter:
router = APIRouter(prefix="/openai-assistants")
router.include_router(assistants_router)
router.include_router(runs_router)
router.include_router(threads_router)
router.include_router(messages_router)
return router

View File

@@ -1,235 +0,0 @@
import uuid
from datetime import datetime
from typing import Any
from typing import Literal
from typing import Optional
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.configs.constants import MessageType
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.llm.utils import check_number_of_tokens
router = APIRouter(prefix="")
Role = Literal["user", "assistant"]
class MessageContent(BaseModel):
type: Literal["text"]
text: str
class Message(BaseModel):
id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}")
object: Literal["thread.message"] = "thread.message"
created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
thread_id: str
role: Role
content: list[MessageContent]
file_ids: list[str] = []
assistant_id: Optional[str] = None
run_id: Optional[str] = None
metadata: Optional[dict[str, Any]] = None # Change this line to use dict[str, Any]
class CreateMessageRequest(BaseModel):
role: Role
content: str
file_ids: list[str] = []
metadata: Optional[dict] = None
class ListMessagesResponse(BaseModel):
object: Literal["list"] = "list"
data: list[Message]
first_id: str
last_id: str
has_more: bool
@router.post("/threads/{thread_id}/messages")
def create_message(
thread_id: str,
message: CreateMessageRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Message:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=uuid.UUID(thread_id),
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Chat session not found")
chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session.id,
user_id=user.id if user else None,
db_session=db_session,
)
latest_message = (
chat_messages[-1]
if chat_messages
else get_or_create_root_message(chat_session.id, db_session)
)
new_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=latest_message,
message=message.content,
prompt_id=chat_session.persona.prompts[0].id,
token_count=check_number_of_tokens(message.content),
message_type=(
MessageType.USER if message.role == "user" else MessageType.ASSISTANT
),
db_session=db_session,
)
return Message(
id=str(new_message.id),
thread_id=thread_id,
role="user",
content=[MessageContent(type="text", text=message.content)],
file_ids=message.file_ids,
metadata=message.metadata,
)
@router.get("/threads/{thread_id}/messages")
def list_messages(
thread_id: str,
limit: int = 20,
order: Literal["asc", "desc"] = "desc",
after: Optional[str] = None,
before: Optional[str] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ListMessagesResponse:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=uuid.UUID(thread_id),
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Chat session not found")
messages = get_chat_messages_by_session(
chat_session_id=chat_session.id,
user_id=user_id,
db_session=db_session,
)
# Apply filtering based on after and before
if after:
messages = [m for m in messages if str(m.id) >= after]
if before:
messages = [m for m in messages if str(m.id) <= before]
# Apply ordering
messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))
# Apply limit
messages = messages[:limit]
data = [
Message(
id=str(m.id),
thread_id=thread_id,
role="user" if m.message_type == "user" else "assistant",
content=[MessageContent(type="text", text=m.message)],
created_at=int(m.time_sent.timestamp()),
)
for m in messages
]
return ListMessagesResponse(
data=data,
first_id=str(data[0].id) if data else "",
last_id=str(data[-1].id) if data else "",
has_more=len(messages) == limit,
)
@router.get("/threads/{thread_id}/messages/{message_id}")
def retrieve_message(
thread_id: str,
message_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Message:
user_id = user.id if user else None
try:
chat_message = get_chat_message(
chat_message_id=message_id,
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Message not found")
return Message(
id=str(chat_message.id),
thread_id=thread_id,
role="user" if chat_message.message_type == "user" else "assistant",
content=[MessageContent(type="text", text=chat_message.message)],
created_at=int(chat_message.time_sent.timestamp()),
)
class ModifyMessageRequest(BaseModel):
metadata: dict
@router.post("/threads/{thread_id}/messages/{message_id}")
def modify_message(
thread_id: str,
message_id: int,
request: ModifyMessageRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Message:
user_id = user.id if user else None
try:
chat_message = get_chat_message(
chat_message_id=message_id,
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Message not found")
# Update metadata
# TODO: Uncomment this once we have metadata in the chat message
# chat_message.metadata = request.metadata
# db_session.commit()
return Message(
id=str(chat_message.id),
thread_id=thread_id,
role="user" if chat_message.message_type == "user" else "assistant",
content=[MessageContent(type="text", text=chat_message.message)],
created_at=int(chat_message.time_sent.timestamp()),
metadata=request.metadata,
)

View File

@@ -1,344 +0,0 @@
from typing import Literal
from typing import Optional
from uuid import UUID
from fastapi import APIRouter
from fastapi import BackgroundTasks
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.process_message import stream_chat_message_objects
from danswer.configs.constants import MessageType
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import ChatMessage
from danswer.db.models import User
from danswer.search.models import RetrievalDetails
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter()
class RunRequest(BaseModel):
assistant_id: int
model: Optional[str] = None
instructions: Optional[str] = None
additional_instructions: Optional[str] = None
tools: Optional[list[dict]] = None
metadata: Optional[dict] = None
RunStatus = Literal[
"queued",
"in_progress",
"requires_action",
"cancelling",
"cancelled",
"failed",
"completed",
"expired",
]
class RunResponse(BaseModel):
id: str
object: Literal["thread.run"]
created_at: int
assistant_id: int
thread_id: UUID
status: RunStatus
started_at: Optional[int] = None
expires_at: Optional[int] = None
cancelled_at: Optional[int] = None
failed_at: Optional[int] = None
completed_at: Optional[int] = None
last_error: Optional[dict] = None
model: str
instructions: str
tools: list[dict]
file_ids: list[str]
metadata: Optional[dict] = None
def process_run_in_background(
message_id: int,
parent_message_id: int,
chat_session_id: UUID,
assistant_id: int,
instructions: str,
tools: list[dict],
user: User | None,
db_session: Session,
) -> None:
# Get the latest message in the chat session
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
user_id=user.id if user else None,
db_session=db_session,
)
search_tool_retrieval_details = RetrievalDetails()
for tool in tools:
if tool["type"] == SearchTool.__name__ and (
retrieval_details := tool.get("retrieval_details")
):
search_tool_retrieval_details = RetrievalDetails.model_validate(
retrieval_details
)
break
new_msg_req = CreateChatMessageRequest(
chat_session_id=chat_session_id,
parent_message_id=int(parent_message_id) if parent_message_id else None,
message=instructions,
file_descriptors=[],
prompt_id=chat_session.persona.prompts[0].id,
search_doc_ids=None,
retrieval_options=search_tool_retrieval_details, # Adjust as needed
query_override=None,
regenerate=None,
llm_override=None,
prompt_override=None,
alternate_assistant_id=assistant_id,
use_existing_user_message=True,
existing_assistant_message_id=message_id,
)
run_message = get_chat_message(message_id, user.id if user else None, db_session)
try:
for packet in stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
):
if isinstance(packet, ChatMessageDetail):
# Update the run status and message content
run_message = get_chat_message(
message_id, user.id if user else None, db_session
)
if run_message:
# this handles cancelling
if run_message.error:
return
run_message.message = packet.message
run_message.message_type = MessageType.ASSISTANT
db_session.commit()
except Exception as e:
logger.exception("Error processing run in background")
run_message.error = str(e)
db_session.commit()
return
db_session.refresh(run_message)
if run_message.token_count == 0:
run_message.error = "No tokens generated"
db_session.commit()
@router.post("/threads/{thread_id}/runs")
def create_run(
thread_id: UUID,
run_request: RunRequest,
background_tasks: BackgroundTasks,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> RunResponse:
try:
chat_session = get_chat_session_by_id(
chat_session_id=thread_id,
user_id=user.id if user else None,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session.id,
user_id=user.id if user else None,
db_session=db_session,
)
latest_message = (
chat_messages[-1]
if chat_messages
else get_or_create_root_message(chat_session.id, db_session)
)
# Create a new "run" (chat message) in the session
new_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=latest_message,
message="",
prompt_id=chat_session.persona.prompts[0].id,
token_count=0,
message_type=MessageType.ASSISTANT,
db_session=db_session,
commit=False,
)
db_session.flush()
latest_message.latest_child_message = new_message.id
db_session.commit()
# Schedule the background task
background_tasks.add_task(
process_run_in_background,
new_message.id,
latest_message.id,
chat_session.id,
run_request.assistant_id,
run_request.instructions or "",
run_request.tools or [],
user,
db_session,
)
return RunResponse(
id=str(new_message.id),
object="thread.run",
created_at=int(new_message.time_sent.timestamp()),
assistant_id=run_request.assistant_id,
thread_id=chat_session.id,
status="queued",
model=run_request.model or "default_model",
instructions=run_request.instructions or "",
tools=run_request.tools or [],
file_ids=[],
metadata=run_request.metadata,
)
@router.get("/threads/{thread_id}/runs/{run_id}")
def retrieve_run(
thread_id: UUID,
run_id: str,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> RunResponse:
# Retrieve the chat message (which represents a "run" in DAnswer)
chat_message = get_chat_message(
chat_message_id=int(run_id), # Convert string run_id to int
user_id=user.id if user else None,
db_session=db_session,
)
if not chat_message:
raise HTTPException(status_code=404, detail="Run not found")
chat_session = chat_message.chat_session
# Map DAnswer status to OpenAI status
run_status: RunStatus = "queued"
if chat_message.message:
run_status = "in_progress"
if chat_message.token_count != 0:
run_status = "completed"
if chat_message.error:
run_status = "cancelled"
return RunResponse(
id=run_id,
object="thread.run",
created_at=int(chat_message.time_sent.timestamp()),
assistant_id=chat_session.persona_id or 0,
thread_id=chat_session.id,
status=run_status,
started_at=int(chat_message.time_sent.timestamp()),
completed_at=(
int(chat_message.time_sent.timestamp()) if chat_message.message else None
),
model=chat_session.current_alternate_model or "default_model",
instructions="", # DAnswer doesn't store per-message instructions
tools=[], # DAnswer doesn't have a direct equivalent for tools
file_ids=(
[file["id"] for file in chat_message.files] if chat_message.files else []
),
metadata=None, # DAnswer doesn't store metadata for individual messages
)
@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
def cancel_run(
thread_id: UUID,
run_id: str,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> RunResponse:
# In DAnswer, we don't have a direct equivalent to cancelling a run
# We'll simulate it by marking the message as "cancelled"
chat_message = (
db_session.query(ChatMessage).filter(ChatMessage.id == run_id).first()
)
if not chat_message:
raise HTTPException(status_code=404, detail="Run not found")
chat_message.error = "Cancelled"
db_session.commit()
return retrieve_run(thread_id, run_id, user, db_session)
@router.get("/threads/{thread_id}/runs")
def list_runs(
thread_id: UUID,
limit: int = 20,
order: Literal["asc", "desc"] = "desc",
after: Optional[str] = None,
before: Optional[str] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[RunResponse]:
# In DAnswer, we'll treat each message in a chat session as a "run"
chat_messages = get_chat_messages_by_session(
chat_session_id=thread_id,
user_id=user.id if user else None,
db_session=db_session,
)
# Apply pagination
if after:
chat_messages = [msg for msg in chat_messages if str(msg.id) > after]
if before:
chat_messages = [msg for msg in chat_messages if str(msg.id) < before]
# Apply ordering
chat_messages = sorted(
chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc")
)
# Apply limit
chat_messages = chat_messages[:limit]
return [
retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages
]
@router.get("/threads/{thread_id}/runs/{run_id}/steps")
def list_run_steps(
run_id: str,
limit: int = 20,
order: Literal["asc", "desc"] = "desc",
after: Optional[str] = None,
before: Optional[str] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[dict]: # You may want to create a specific model for run steps
# DAnswer doesn't have an equivalent to run steps
# We'll return an empty list to maintain API compatibility
return []
# Additional helper functions can be added here if needed

View File

@@ -1,156 +0,0 @@
from typing import Optional
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.chat import create_chat_session
from danswer.db.chat import delete_chat_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_chat_sessions_by_user
from danswer.db.chat import update_chat_session
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.query_and_chat.models import ChatSessionDetails
from danswer.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter(prefix="/threads")
# Models
class Thread(BaseModel):
id: UUID
object: str = "thread"
created_at: int
metadata: Optional[dict[str, str]] = None
class CreateThreadRequest(BaseModel):
messages: Optional[list[dict]] = None
metadata: Optional[dict[str, str]] = None
class ModifyThreadRequest(BaseModel):
metadata: Optional[dict[str, str]] = None
# API Endpoints
@router.post("")
def create_thread(
request: CreateThreadRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Thread:
user_id = user.id if user else None
new_chat_session = create_chat_session(
db_session=db_session,
description="", # Leave the naming till later to prevent delay
user_id=user_id,
persona_id=0,
)
return Thread(
id=new_chat_session.id,
created_at=int(new_chat_session.time_created.timestamp()),
metadata=request.metadata,
)
@router.get("/{thread_id}")
def retrieve_thread(
thread_id: UUID,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Thread:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=thread_id,
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
return Thread(
id=chat_session.id,
created_at=int(chat_session.time_created.timestamp()),
metadata=None, # Assuming we don't store metadata in our current implementation
)
@router.post("/{thread_id}")
def modify_thread(
thread_id: UUID,
request: ModifyThreadRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Thread:
user_id = user.id if user else None
try:
chat_session = update_chat_session(
db_session=db_session,
user_id=user_id,
chat_session_id=thread_id,
description=None, # Not updating description
sharing_status=None, # Not updating sharing status
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
return Thread(
id=chat_session.id,
created_at=int(chat_session.time_created.timestamp()),
metadata=request.metadata,
)
@router.delete("/{thread_id}")
def delete_thread(
thread_id: UUID,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict:
user_id = user.id if user else None
try:
delete_chat_session(
user_id=user_id,
chat_session_id=thread_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
return {"id": str(thread_id), "object": "thread.deleted", "deleted": True}
@router.get("")
def list_threads(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
user_id = user.id if user else None
chat_sessions = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
)
return ChatSessionsResponse(
sessions=[
ChatSessionDetails(
id=chat.id,
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
)
for chat in chat_sessions
]
)

View File

@@ -347,6 +347,7 @@ def handle_new_chat_message(
for packet in stream_chat_message(
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),

View File

@@ -108,9 +108,6 @@ class CreateChatMessageRequest(ChunkContext):
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
# used for "OpenAI Assistants API"
existing_assistant_message_id: int | None = None
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None

View File

@@ -59,9 +59,7 @@ from shared_configs.model_server_models import SupportedEmbeddingModel
logger = setup_logger()
def setup_danswer(
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
) -> None:
def setup_danswer(db_session: Session, tenant_id: str | None) -> None:
"""
Setup Danswer for a particular tenant. In the Single Tenant case, it will set it up for the default schema
on server startup. In the MT case, it will be called when the tenant is created.
@@ -150,7 +148,7 @@ def setup_danswer(
# update multipass indexing setting based on GPU availability
update_default_multipass_indexing(db_session)
seed_initial_documents(db_session, tenant_id, cohere_enabled)
seed_initial_documents(db_session, tenant_id)
def translate_saved_search_settings(db_session: Session) -> None:

View File

@@ -1,255 +0,0 @@
from typing import cast
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
"""Helper function to get image generation LLM config based on available providers"""
if llm and llm.config.api_key and llm.config.model_provider == "openai":
return LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
if llm.config.model_provider == "azure" and AZURE_DALLE_API_KEY is not None:
return LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
# Fallback to checking for OpenAI provider in database
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError("Image generation tool requires an OpenAI API key")
return LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
class SearchToolConfig(BaseModel):
answer_style_config: AnswerStyleConfig = Field(
default_factory=lambda: AnswerStyleConfig(citation_config=CitationConfig())
)
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
selected_sections: list[InferenceSection] | None = None
chunks_above: int = 0
chunks_below: int = 0
full_doc: bool = False
latest_query_files: list[InMemoryChatFile] | None = None
class InternetSearchToolConfig(BaseModel):
answer_style_config: AnswerStyleConfig = Field(
default_factory=lambda: AnswerStyleConfig(
citation_config=CitationConfig(all_docs_useful=True)
)
)
class ImageGenerationToolConfig(BaseModel):
additional_headers: dict[str, str] | None = None
class CustomToolConfig(BaseModel):
chat_session_id: UUID | None = None
message_id: int | None = None
additional_headers: dict[str, str] | None = None
def construct_tools(
persona: Persona,
prompt_config: PromptConfig,
db_session: Session,
user: User | None,
llm: LLM,
fast_llm: LLM,
search_tool_config: SearchToolConfig | None = None,
internet_search_tool_config: InternetSearchToolConfig | None = None,
image_generation_tool_config: ImageGenerationToolConfig | None = None,
custom_tool_config: CustomToolConfig | None = None,
) -> dict[int, list[Tool]]:
"""Constructs tools based on persona configuration and available APIs"""
tool_dict: dict[int, list[Tool]] = {}
for db_tool_model in persona.tools:
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
# Handle Search Tool
if tool_cls.__name__ == SearchTool.__name__:
if not search_tool_config:
search_tool_config = SearchToolConfig()
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
selected_sections=search_tool_config.selected_sections,
chunks_above=search_tool_config.chunks_above,
chunks_below=search_tool_config.chunks_below,
full_doc=search_tool_config.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
)
tool_dict[db_tool_model.id] = [search_tool]
# Handle Image Generation Tool
elif tool_cls.__name__ == ImageGenerationTool.__name__:
if not image_generation_tool_config:
image_generation_tool_config = ImageGenerationToolConfig()
img_generation_llm_config = _get_image_generation_config(
llm, db_session
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=image_generation_tool_config.additional_headers,
model=img_generation_llm_config.model_name,
)
]
# Handle Internet Search Tool
elif tool_cls.__name__ == InternetSearchTool.__name__:
if not internet_search_tool_config:
internet_search_tool_config = InternetSearchToolConfig()
if not BING_API_KEY:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(
api_key=BING_API_KEY,
answer_style_config=internet_search_tool_config.answer_style_config,
prompt_config=prompt_config,
)
]
# Handle custom tools
elif db_tool_model.openapi_schema:
if not custom_tool_config:
custom_tool_config = CustomToolConfig()
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=custom_tool_config.chat_session_id,
message_id=custom_tool_config.message_id,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_config.additional_headers or {}
)
),
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
if search_tool_config:
search_tool_config.document_pruning_config.tool_num_tokens = (
compute_all_tool_tokens(
tools,
get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
),
)
)
search_tool_config.document_pruning_config.using_tool_message = (
explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
)
return tool_dict

View File

@@ -1,6 +1,5 @@
import functools
import importlib
import inspect
from typing import Any
from typing import TypeVar
@@ -140,19 +139,8 @@ def fetch_ee_implementation_or_noop(
Exception: If EE is enabled but the fetch fails.
"""
if not global_version.is_ee_version():
if inspect.iscoroutinefunction(noop_return_value):
return lambda *args, **kwargs: noop_return_value
async def async_noop(*args: Any, **kwargs: Any) -> Any:
return await noop_return_value(*args, **kwargs)
return async_noop
else:
def sync_noop(*args: Any, **kwargs: Any) -> Any:
return noop_return_value
return sync_noop
try:
return fetch_versioned_implementation(module, attribute)
except Exception as e:

View File

@@ -4,7 +4,6 @@ import uuid
import aiohttp # Async HTTP client
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.auth.users import exceptions
@@ -14,8 +13,6 @@ from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.models import UserTenantMapping
from danswer.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from danswer.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
@@ -105,19 +102,9 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_danswer(db_session, tenant_id, cohere_enabled=cohere_enabled)
add_users_to_tenant([email], tenant_id)
except Exception as e:
@@ -213,51 +200,11 @@ def configure_default_api_keys(db_session: Session) -> None:
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
try:
logger.info("Attempting to upsert Cohere cloud embedding provider")
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
logger.info("Successfully upserted Cohere cloud embedding provider")
logger.info("Updating search settings with Cohere embedding model details")
query = (
select(SearchSettings)
.where(SearchSettings.status == IndexModelStatus.FUTURE)
.order_by(SearchSettings.id.desc())
)
result = db_session.execute(query)
current_search_settings = result.scalars().first()
if current_search_settings:
current_search_settings.model_name = (
"embed-english-v3.0" # Cohere's latest model as of now
)
current_search_settings.model_dim = (
1024 # Cohere's embed-english-v3.0 dimension
)
current_search_settings.provider_type = EmbeddingProvider.COHERE
current_search_settings.index_name = (
"danswer_chunk_cohere_embed_english_v3_0"
)
current_search_settings.query_prefix = ""
current_search_settings.passage_prefix = ""
db_session.commit()
else:
raise RuntimeError(
"No search settings specified, DB is not in a valid state"
)
logger.info("Fetching updated search settings to verify changes")
updated_query = (
select(SearchSettings)
.where(SearchSettings.status == IndexModelStatus.PRESENT)
.order_by(SearchSettings.id.desc())
)
updated_result = db_session.execute(updated_query)
updated_result.scalars().first()
except Exception:
logger.exception("Failed to configure Cohere embedding provider")
except Exception as e:
logger.error(f"Failed to configure Cohere embedding provider: {e}")
else:
logger.info(
logger.error(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
)

View File

@@ -142,20 +142,6 @@ async def async_return_default_schema(*args: Any, **kwargs: Any) -> str:
# Prefix used for all tenant ids
TENANT_ID_PREFIX = "tenant_"
ALLOWED_SLACK_BOT_TENANT_IDS = os.environ.get("ALLOWED_SLACK_BOT_TENANT_IDS")
DISALLOWED_SLACK_BOT_TENANT_LIST = (
[tenant.strip() for tenant in ALLOWED_SLACK_BOT_TENANT_IDS.split(",")]
if ALLOWED_SLACK_BOT_TENANT_IDS
else None
)
IGNORED_SYNCING_TENANT_IDS = os.environ.get("IGNORED_SYNCING_TENANT_ID")
IGNORED_SYNCING_TENANT_LIST = (
[tenant.strip() for tenant in IGNORED_SYNCING_TENANT_IDS.split(",")]
if IGNORED_SYNCING_TENANT_IDS
else None
)
SUPPORTED_EMBEDDING_MODELS = [
# Cloud-based models
SupportedEmbeddingModel(

View File

@@ -13,14 +13,6 @@ from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
DOMAIN = "test.com"
DEFAULT_PASSWORD = "test"
def build_email(name: str) -> str:
return f"{name}@test.com"
class UserManager:
@staticmethod
def create(
@@ -31,9 +23,9 @@ class UserManager:
name = f"test{str(uuid4())}"
if email is None:
email = build_email(name)
email = f"{name}@test.com"
password = DEFAULT_PASSWORD
password = "test"
body = {
"email": email,

View File

@@ -1,55 +0,0 @@
from typing import Optional
from uuid import UUID
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
@pytest.fixture
def admin_user() -> DATestUser | None:
try:
return UserManager.create("admin_user")
except Exception:
pass
try:
return UserManager.login_as_user(
DATestUser(
id="",
email=build_email("admin_user"),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
)
)
except Exception:
pass
return None
@pytest.fixture
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
return LLMProviderManager.create(user_performing_action=admin_user)
@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> UUID:
# Create a thread to use in the tests
response = requests.post(
f"{BASE_URL}/threads", # Updated endpoint path
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return UUID(response.json()["id"])

View File

@@ -1,151 +0,0 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
ASSISTANTS_URL = f"{API_SERVER_URL}/openai-assistants/assistants"
def test_create_assistant(admin_user: DATestUser | None) -> None:
response = requests.post(
ASSISTANTS_URL,
json={
"model": "gpt-3.5-turbo",
"name": "Test Assistant",
"description": "A test assistant",
"instructions": "You are a helpful assistant.",
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Test Assistant"
assert data["description"] == "A test assistant"
assert data["model"] == "gpt-3.5-turbo"
assert data["instructions"] == "You are a helpful assistant."
def test_retrieve_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Retrieve Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, retrieve the assistant
response = requests.get(
f"{ASSISTANTS_URL}/{assistant_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["name"] == "Retrieve Test"
def test_modify_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Modify Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, modify the assistant
response = requests.post(
f"{ASSISTANTS_URL}/{assistant_id}",
json={"name": "Modified Assistant", "instructions": "New instructions"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["name"] == "Modified Assistant"
assert data["instructions"] == "New instructions"
def test_delete_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Delete Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, delete the assistant
response = requests.delete(
f"{ASSISTANTS_URL}/{assistant_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["deleted"] is True
def test_list_assistants(admin_user: DATestUser | None) -> None:
# Create multiple assistants
for i in range(3):
requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": f"List Test {i}"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the assistants
response = requests.get(
ASSISTANTS_URL,
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["object"] == "list"
assert len(data["data"]) >= 3 # At least the 3 we just created
assert all(assistant["object"] == "assistant" for assistant in data["data"])
def test_list_assistants_pagination(admin_user: DATestUser | None) -> None:
# Create 5 assistants
for i in range(5):
requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": f"Pagination Test {i}"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# List assistants with limit
response = requests.get(
f"{ASSISTANTS_URL}?limit=2",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 2
assert data["has_more"] is True
# Get next page
before = data["last_id"]
response = requests.get(
f"{ASSISTANTS_URL}?limit=2&before={before}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 2
def test_assistant_not_found(admin_user: DATestUser | None) -> None:
non_existent_id = -99
response = requests.get(
f"{ASSISTANTS_URL}/{non_existent_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404

View File

@@ -1,133 +0,0 @@
import uuid
from typing import Optional
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants/threads"
@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> str:
response = requests.post(
BASE_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return response.json()["id"]
def test_create_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
response = requests.post(
f"{BASE_URL}/{thread_id}/messages", # URL structure matches API
json={
"role": "user",
"content": "Hello, world!",
"file_ids": [],
"metadata": {"key": "value"},
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["thread_id"] == thread_id
assert response_json["role"] == "user"
assert response_json["content"] == [{"type": "text", "text": "Hello, world!"}]
assert response_json["metadata"] == {"key": "value"}
def test_list_messages(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the messages
response = requests.get(
f"{BASE_URL}/{thread_id}/messages",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["object"] == "list"
assert isinstance(response_json["data"], list)
assert len(response_json["data"]) > 0
assert "first_id" in response_json
assert "last_id" in response_json
assert "has_more" in response_json
def test_retrieve_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
create_response = requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
message_id = create_response.json()["id"]
# Now, retrieve the message
response = requests.get(
f"{BASE_URL}/{thread_id}/messages/{message_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["id"] == message_id
assert response_json["thread_id"] == thread_id
assert response_json["role"] == "user"
assert response_json["content"] == [{"type": "text", "text": "Test message"}]
def test_modify_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
create_response = requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
message_id = create_response.json()["id"]
# Now, modify the message
response = requests.post(
f"{BASE_URL}/{thread_id}/messages/{message_id}",
json={"metadata": {"new_key": "new_value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["id"] == message_id
assert response_json["thread_id"] == thread_id
assert response_json["metadata"] == {"new_key": "new_value"}
def test_error_handling(admin_user: Optional[DATestUser]) -> None:
non_existent_thread_id = str(uuid.uuid4())
non_existent_message_id = -99
# Test with non-existent thread
response = requests.post(
f"{BASE_URL}/{non_existent_thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404
# Test with non-existent message
response = requests.get(
f"{BASE_URL}/{non_existent_thread_id}/messages/{non_existent_message_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404

View File

@@ -1,137 +0,0 @@
from uuid import UUID
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
@pytest.fixture
def run_id(admin_user: DATestUser | None, thread_id: UUID) -> str:
"""Create a run and return its ID."""
response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return response.json()["id"]
def test_create_run(
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
"model": "gpt-3.5-turbo",
"instructions": "Test instructions",
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["object"] == "thread.run"
assert "created_at" in response_json
assert response_json["assistant_id"] == 0
assert UUID(response_json["thread_id"]) == thread_id
assert response_json["status"] == "queued"
assert response_json["model"] == "gpt-3.5-turbo"
assert response_json["instructions"] == "Test instructions"
def test_retrieve_run(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
retrieve_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert retrieve_response.status_code == 200
response_json = retrieve_response.json()
assert response_json["id"] == run_id
assert response_json["object"] == "thread.run"
assert "created_at" in response_json
assert UUID(response_json["thread_id"]) == thread_id
def test_cancel_run(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
cancel_response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/cancel",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert cancel_response.status_code == 200
response_json = cancel_response.json()
assert response_json["id"] == run_id
assert response_json["status"] == "cancelled"
def test_list_runs(
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
# Create a few runs
for _ in range(3):
requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the runs
list_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert list_response.status_code == 200
response_json = list_response.json()
assert isinstance(response_json, list)
assert len(response_json) >= 3
for run in response_json:
assert "id" in run
assert run["object"] == "thread.run"
assert "created_at" in run
assert UUID(run["thread_id"]) == thread_id
assert "status" in run
assert "model" in run
def test_list_run_steps(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
steps_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/steps",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert steps_response.status_code == 200
response_json = steps_response.json()
assert isinstance(response_json, list)
# Since DAnswer doesn't have an equivalent to run steps, we expect an empty list
assert len(response_json) == 0

View File

@@ -1,132 +0,0 @@
from uuid import UUID
import requests
from danswer.db.models import ChatSessionSharedStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
THREADS_URL = f"{API_SERVER_URL}/openai-assistants/threads"
def test_create_thread(admin_user: DATestUser | None) -> None:
response = requests.post(
THREADS_URL,
json={"messages": None, "metadata": {"key": "value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["object"] == "thread"
assert "created_at" in response_json
assert response_json["metadata"] == {"key": "value"}
def test_retrieve_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, retrieve the thread
retrieve_response = requests.get(
f"{THREADS_URL}/{thread_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert retrieve_response.status_code == 200
response_json = retrieve_response.json()
assert response_json["id"] == thread_id
assert response_json["object"] == "thread"
assert "created_at" in response_json
def test_modify_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, modify the thread
modify_response = requests.post(
f"{THREADS_URL}/{thread_id}",
json={"metadata": {"new_key": "new_value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert modify_response.status_code == 200
response_json = modify_response.json()
assert response_json["id"] == thread_id
assert response_json["metadata"] == {"new_key": "new_value"}
def test_delete_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, delete the thread
delete_response = requests.delete(
f"{THREADS_URL}/{thread_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert delete_response.status_code == 200
response_json = delete_response.json()
assert response_json["id"] == thread_id
assert response_json["object"] == "thread.deleted"
assert response_json["deleted"] is True
def test_list_threads(admin_user: DATestUser | None) -> None:
# Create a few threads
for _ in range(3):
requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the threads
list_response = requests.get(
THREADS_URL,
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert list_response.status_code == 200
response_json = list_response.json()
assert "sessions" in response_json
assert len(response_json["sessions"]) >= 3
for session in response_json["sessions"]:
assert "id" in session
assert "name" in session
assert "persona_id" in session
assert "time_created" in session
assert "shared_status" in session
assert "folder_id" in session
assert "current_alternate_model" in session
# Validate UUID
UUID(session["id"])
# Validate shared_status
assert session["shared_status"] in [
status.value for status in ChatSessionSharedStatus
]

View File

@@ -9,11 +9,12 @@ spec:
scaleTargetRef:
name: celery-worker-indexing
minReplicaCount: 1
maxReplicaCount: 30
maxReplicaCount: 10
triggers:
- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing
@@ -21,10 +22,10 @@ spec:
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing:2
@@ -35,6 +36,7 @@ spec:
- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing:3
@@ -42,12 +44,3 @@ spec:
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: cpu
metadata:
type: Utilization
value: "70"
- type: memory
metadata:
type: Utilization
value: "70"

View File

@@ -8,11 +8,12 @@ metadata:
spec:
scaleTargetRef:
name: celery-worker-light
minReplicaCount: 5
minReplicaCount: 1
maxReplicaCount: 20
triggers:
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync
@@ -22,6 +23,7 @@ spec:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync:2
@@ -31,6 +33,7 @@ spec:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync:3
@@ -40,6 +43,7 @@ spec:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_deletion
@@ -49,6 +53,7 @@ spec:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_deletion:2

View File

@@ -15,6 +15,7 @@ spec:
triggers:
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery
@@ -25,6 +26,7 @@ spec:
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery:1
@@ -34,6 +36,7 @@ spec:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery:2
@@ -43,6 +46,7 @@ spec:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery:3
@@ -52,6 +56,7 @@ spec:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: periodic_tasks
@@ -61,6 +66,7 @@ spec:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: periodic_tasks:2

View File

@@ -1,19 +0,0 @@
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: indexing-model-server-scaledobject
namespace: danswer
labels:
app: indexing-model-server
spec:
scaleTargetRef:
name: indexing-model-server-deployment
pollingInterval: 15 # Check every 15 seconds
cooldownPeriod: 30 # Wait 30 seconds before scaling down
minReplicaCount: 1
maxReplicaCount: 14
triggers:
- type: cpu
metadata:
type: Utilization
value: "70"

View File

@@ -5,5 +5,5 @@ metadata:
namespace: danswer
type: Opaque
data:
host: { base64 encoded host here }
password: { base64 encoded password here }
host: { { base64-encoded-hostname } }
password: { { base64-encoded-password } }

View File

@@ -14,8 +14,8 @@ spec:
spec:
containers:
- name: celery-beat
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
imagePullPolicy: IfNotPresent
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
imagePullPolicy: Always
command:
[
"celery",

View File

@@ -14,8 +14,8 @@ spec:
spec:
containers:
- name: celery-worker-heavy
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
imagePullPolicy: IfNotPresent
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
imagePullPolicy: Always
command:
[
"celery",

View File

@@ -14,8 +14,8 @@ spec:
spec:
containers:
- name: celery-worker-indexing
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
imagePullPolicy: IfNotPresent
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
imagePullPolicy: Always
command:
[
"celery",
@@ -47,10 +47,10 @@ spec:
resources:
requests:
cpu: "500m"
memory: "4Gi"
memory: "1Gi"
limits:
cpu: "1000m"
memory: "8Gi"
memory: "2Gi"
volumes:
- name: vespa-certificates
secret:

View File

@@ -14,8 +14,8 @@ spec:
spec:
containers:
- name: celery-worker-light
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
imagePullPolicy: IfNotPresent
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
imagePullPolicy: Always
command:
[
"celery",

View File

@@ -14,8 +14,8 @@ spec:
spec:
containers:
- name: celery-worker-primary
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
imagePullPolicy: IfNotPresent
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
imagePullPolicy: Always
command:
[
"celery",

View File

@@ -1,125 +0,0 @@
import argparse
import os
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from openai import OpenAI
ASSISTANT_NAME = "Topic Analyzer"
SYSTEM_PROMPT = """
You are a helpful assistant that analyzes topics by searching through available \
documents and providing insights. These available documents come from common \
workplace tools like Slack, emails, Confluence, Google Drive, etc.
When analyzing a topic:
1. Search for relevant information using the search tool
2. Synthesize the findings into clear insights
3. Highlight key trends, patterns, or notable developments
4. Maintain objectivity and cite sources where relevant
"""
USER_PROMPT = """
Please analyze and provide insights about this topic: {topic}.
IMPORTANT: do not mention things that are not relevant to the specified topic. \
If there is no relevant information, just say "No relevant information found."
"""
def wait_on_run(client: OpenAI, run, thread):
while run.status == "queued" or run.status == "in_progress":
run = client.beta.threads.runs.retrieve(
thread_id=thread.id,
run_id=run.id,
)
time.sleep(0.5)
return run
def show_response(messages) -> None:
# Get only the assistant's response text
for message in messages.data[::-1]:
if message.role == "assistant":
for content in message.content:
if content.type == "text":
print(content.text)
break
def analyze_topics(topics: list[str]) -> None:
openai_api_key = os.environ.get(
"OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"
)
danswer_api_key = os.environ.get(
"DANSWER_API_KEY", "<your Danswer API key if not set as env var>"
)
client = OpenAI(
api_key=openai_api_key,
base_url="http://localhost:8080/openai-assistants",
default_headers={
"Authorization": f"Bearer {danswer_api_key}",
},
)
# Create an assistant if it doesn't exist
try:
assistants = client.beta.assistants.list(limit=100)
# Find the Topic Analyzer assistant if it exists
assistant = next((a for a in assistants.data if a.name == ASSISTANT_NAME))
client.beta.assistants.delete(assistant.id)
except Exception:
pass
assistant = client.beta.assistants.create(
name=ASSISTANT_NAME,
instructions=SYSTEM_PROMPT,
tools=[{"type": "SearchTool"}], # type: ignore
model="gpt-4o",
)
# Process each topic individually
for topic in topics:
thread = client.beta.threads.create()
message = client.beta.threads.messages.create(
thread_id=thread.id,
role="user",
content=USER_PROMPT.format(topic=topic),
)
run = client.beta.threads.runs.create(
thread_id=thread.id,
assistant_id=assistant.id,
tools=[
{ # type: ignore
"type": "SearchTool",
"retrieval_details": {
"run_search": "always",
"filters": {
"time_cutoff": str(
datetime.now(timezone.utc) - timedelta(days=7)
)
},
},
}
],
)
run = wait_on_run(client, run, thread)
messages = client.beta.threads.messages.list(
thread_id=thread.id, order="asc", after=message.id
)
print(f"\nAnalysis for topic: {topic}")
print("-" * 40)
show_response(messages)
print()
# Example usage
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Analyze specific topics")
parser.add_argument("topics", nargs="+", help="Topics to analyze (one or more)")
args = parser.parse_args()
analyze_topics(args.topics)

View File

@@ -68,7 +68,7 @@ export function IndexAttemptStatus({
);
} else if (status === "in_progress") {
badge = (
<Badge variant="in_progress" icon={FiClock}>
<Badge className="flex-none" variant="in_progress" icon={FiClock}>
In Progress
</Badge>
);