mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 08:15:48 +00:00
Compare commits
3 Commits
cohere_def
...
nit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c68602f456 | ||
|
|
9d57f34c34 | ||
|
|
cc2f584321 |
@@ -288,15 +288,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
|
||||
# below
|
||||
op.execute("DELETE FROM chat_feedback")
|
||||
op.execute("DELETE FROM chat_message__search_doc")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM chat_message")
|
||||
op.execute("DELETE FROM chat_session")
|
||||
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
|
||||
@@ -23,56 +23,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete chat messages and feedback first since they reference chat sessions
|
||||
# Get chat messages from sessions with null persona_id
|
||||
chat_messages_query = """
|
||||
SELECT id
|
||||
FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
|
||||
# Delete dependent records first
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM document_retrieval_feedback
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM chat_message__search_doc
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Delete chat messages
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Now we can safely delete the chat sessions
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
|
||||
@@ -12,7 +12,6 @@ from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
@@ -73,15 +72,6 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
IGNORED_SYNCING_TENANT_LIST
|
||||
and tenant_id in IGNORED_SYNCING_TENANT_LIST
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
continue
|
||||
|
||||
if tenant_id not in existing_tenants:
|
||||
logger.info(f"Processing new tenant: {tenant_id}")
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
@@ -82,11 +81,6 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
|
||||
96
backend/danswer/background/celery/apps/scheduler.py
Normal file
96
backend/danswer/background/celery/apps/scheduler.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
class DynamicTenantScheduler(PersistentScheduler):
    """Celery beat scheduler that dynamically registers per-tenant tasks.

    Extends celery's ``PersistentScheduler``: on each ``tick`` (rate-limited
    to once per ``_reload_interval``) it re-reads the tenant list from the
    database and merges a beat entry for every (task, tenant) pair into the
    persisted schedule, so newly created tenants start receiving background
    tasks without a beat restart.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._reload_interval = timedelta(minutes=1)
        # Backdate the last reload so the very first tick triggers a refresh.
        self._last_reload = self.app.now() - self._reload_interval

    def setup_schedule(self) -> None:
        super().setup_schedule()

    def tick(self) -> float:
        # Let the parent scheduler fire any due tasks first; it returns the
        # number of seconds to sleep until the next tick.
        retval = super().tick()
        now = self.app.now()
        # NOTE: `_last_reload` is always set in __init__, so the original
        # `is None` guard was dead code and has been removed.
        if (now - self._last_reload) > self._reload_interval:
            logger.info("Reloading schedule to check for new tenants...")
            self._update_tenant_tasks()
            self._last_reload = now
        return retval

    def _update_tenant_tasks(self) -> None:
        """Rebuild the per-tenant beat entries and merge them into the store.

        Errors are logged and swallowed so a transient DB/config failure
        never kills the beat loop.
        """
        logger.info("Checking for tenant task updates...")
        try:
            tenant_ids = get_all_tenant_ids()
            tasks_to_schedule = fetch_versioned_implementation(
                "danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
            )

            new_beat_schedule: dict[str, dict[str, Any]] = {}

            # PersistentScheduler keeps its entries in a shelve-backed
            # ``_store``; fall back to an empty mapping before first persist.
            current_schedule = getattr(self, "_store", {"entries": {}}).get(
                "entries", {}
            )

            # Entry names are "<task name>-<tenant id>", so the segment after
            # the last "-" is treated as the tenant id.
            # NOTE(review): this parsing breaks if tenant ids themselves
            # contain "-" (e.g. UUIDs) -- confirm the tenant id format.
            existing_tenants = set()
            for task_name in current_schedule.keys():
                if "-" in task_name:
                    existing_tenants.add(task_name.split("-")[-1])

            for tenant_id in tenant_ids:
                if tenant_id not in existing_tenants:
                    logger.info(f"Found new tenant: {tenant_id}")

                # Build entries for every tenant (not only new ones) so the
                # name-set comparison below sees the full desired schedule.
                for task in tasks_to_schedule():
                    task_name = f"{task['name']}-{tenant_id}"
                    new_task = {
                        "task": task["task"],
                        "schedule": task["schedule"],
                        "kwargs": {"tenant_id": tenant_id},
                    }
                    if options := task.get("options"):
                        new_task["options"] = options
                    new_beat_schedule[task_name] = new_task

            if self._should_update_schedule(current_schedule, new_beat_schedule):
                logger.info(
                    "Updating schedule",
                    extra={
                        "new_tasks": len(new_beat_schedule),
                        "current_tasks": len(current_schedule),
                    },
                )
                if not hasattr(self, "_store"):
                    self._store: dict[str, dict] = {"entries": {}}
                # ``update_from_dict`` merges entries; stale tenants are never
                # removed here -- only additions are handled.
                self.update_from_dict(new_beat_schedule)
                logger.info(f"New schedule: {new_beat_schedule}")

                logger.info("Tenant tasks updated successfully")
            else:
                logger.debug("No schedule updates needed")

        except (AttributeError, KeyError):
            # Malformed task configuration (missing keys/attributes).
            logger.exception("Failed to process task configuration")
        except Exception:
            logger.exception("Unexpected error updating tenant tasks")

    def _should_update_schedule(
        self, current_schedule: dict, new_schedule: dict
    ) -> bool:
        """Compare schedules to determine if an update is needed.

        Only the task-name sets are compared; a change to an existing
        entry's ``schedule`` or ``options`` will not trigger an update.
        """
        current_tasks = set(current_schedule.keys())
        new_tasks = set(new_schedule.keys())
        return current_tasks != new_tasks
|
||||
@@ -8,7 +8,7 @@ tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
@@ -20,13 +20,13 @@ tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -29,26 +29,18 @@ JobStatusType = (
|
||||
def _initializer(
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Initialize the child process with a fresh SQLAlchemy Engine.
|
||||
"""Ensure the parent proc's database connections are not touched
|
||||
in the new connection pool
|
||||
|
||||
Based on SQLAlchemy's recommendations to handle multiprocessing:
|
||||
Based on the recommended approach in the SQLAlchemy docs found:
|
||||
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
logger.info("Initializing spawned worker child process.")
|
||||
|
||||
# Reset the engine in the child process
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
# Optionally set a custom app name for database logging purposes
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
||||
|
||||
# Initialize a new engine with desired parameters
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
|
||||
|
||||
# Proceed with executing the target function
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -19,10 +19,16 @@ from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
@@ -35,6 +41,7 @@ from danswer.db.chat import reserve_message_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
@@ -54,13 +61,14 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
@@ -69,14 +77,14 @@ from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_constructor import construct_tools
|
||||
from danswer.tools.tool_constructor import CustomToolConfig
|
||||
from danswer.tools.tool_constructor import ImageGenerationToolConfig
|
||||
from danswer.tools.tool_constructor import InternetSearchToolConfig
|
||||
from danswer.tools.tool_constructor import SearchToolConfig
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
@@ -87,6 +95,9 @@ from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
@@ -111,6 +122,9 @@ from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.headers import header_dict_to_header_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
@@ -281,6 +295,7 @@ def stream_chat_message_objects(
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
@@ -292,9 +307,6 @@ def stream_chat_message_objects(
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
|
||||
|
||||
# Currently surrounding context is not supported for chat
|
||||
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
|
||||
new_msg_req.chunks_above = 0
|
||||
@@ -416,20 +428,12 @@ def stream_chat_message_objects(
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
if existing_assistant_message_id is None:
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
else:
|
||||
if final_msg.id != existing_assistant_message_id:
|
||||
raise RuntimeError(
|
||||
"The last message was not the existing assistant message. "
|
||||
f"Final message id: {final_msg.id}, "
|
||||
f"existing assistant message id: {existing_assistant_message_id}"
|
||||
)
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
@@ -500,19 +504,13 @@ def stream_chat_message_objects(
|
||||
),
|
||||
max_window_percentage=max_document_percentage,
|
||||
)
|
||||
|
||||
# we don't need to reserve a message id if we're using an existing assistant message
|
||||
reserved_message_id = (
|
||||
final_msg.id
|
||||
if existing_assistant_message_id is not None
|
||||
else reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id if user_message else None,
|
||||
@@ -527,13 +525,7 @@ def stream_chat_message_objects(
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
# if we're using an existing assistant message, then this will just be an
|
||||
# update operation, in which case the parent should be the parent of
|
||||
# the latest. If we're creating a new assistant message, then the parent
|
||||
# should be the latest message (latest user message)
|
||||
parent_message=(
|
||||
final_msg if existing_assistant_message_id is None else parent_message
|
||||
),
|
||||
parent_message=final_msg,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
# message=,
|
||||
@@ -545,7 +537,6 @@ def stream_chat_message_objects(
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
if not final_msg.prompt:
|
||||
@@ -569,39 +560,142 @@ def stream_chat_message_objects(
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
)
|
||||
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
prompt_config=prompt_config,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
),
|
||||
image_generation_tool_config=ImageGenerationToolConfig(
|
||||
additional_headers=litellm_additional_headers,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
)
|
||||
# find out what tools to use
|
||||
search_tool: SearchTool | None = None
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
answer_style_config=answer_style_config,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
img_generation_llm_config: LLMConfig | None = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm.config.api_key,
|
||||
api_base=llm.config.api_base,
|
||||
api_version=llm.config.api_version,
|
||||
)
|
||||
elif (
|
||||
llm.config.model_provider == "azure"
|
||||
and AZURE_DALLE_API_KEY is not None
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider="azure",
|
||||
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=AZURE_DALLE_API_KEY,
|
||||
api_base=AZURE_DALLE_API_BASE,
|
||||
api_version=AZURE_DALLE_API_VERSION,
|
||||
)
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
model=img_generation_llm_config.model_name,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
bing_api_key = BING_API_KEY
|
||||
if not bing_api_key:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(
|
||||
api_key=bing_api_key,
|
||||
answer_style_config=answer_style_config,
|
||||
prompt_config=prompt_config,
|
||||
)
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
custom_headers=(db_tool_model.custom_headers or [])
|
||||
+ (
|
||||
header_dict_to_header_list(
|
||||
custom_tool_additional_headers or {}
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
tools, llm_tokenizer
|
||||
)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm_provider, llm_model_name
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
is_connected=is_connected,
|
||||
@@ -777,6 +871,7 @@ def stream_chat_message_objects(
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
@@ -784,11 +879,9 @@ def stream_chat_message_objects(
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=(
|
||||
message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None
|
||||
),
|
||||
citations=message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
@@ -822,6 +915,7 @@ def stream_chat_message_objects(
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
@@ -831,6 +925,7 @@ def stream_chat_message(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
is_connected=is_connected,
|
||||
|
||||
@@ -55,11 +55,11 @@ def validate_channel_names(
|
||||
# Scaling configurations for multi-tenant Slack bot handling
|
||||
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
|
||||
TENANT_HEARTBEAT_INTERVAL = (
|
||||
15 # How often pods send heartbeats to indicate they are still processing a tenant
|
||||
60 # How often pods send heartbeats to indicate they are still processing a tenant
|
||||
)
|
||||
TENANT_HEARTBEAT_EXPIRATION = (
|
||||
30 # How long before a tenant's heartbeat expires, allowing other pods to take over
|
||||
TENANT_HEARTBEAT_EXPIRATION = 180 # How long before a tenant's heartbeat expires, allowing other pods to take over
|
||||
TENANT_ACQUISITION_INTERVAL = (
|
||||
60 # How often pods attempt to acquire unprocessed tenants
|
||||
)
|
||||
TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and checks for new tokens
|
||||
|
||||
MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))
|
||||
|
||||
@@ -75,7 +75,6 @@ from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import DISALLOWED_SLACK_BOT_TENANT_LIST
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -165,15 +164,9 @@ class SlackbotHandler:
|
||||
|
||||
def acquire_tenants(self) -> None:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
|
||||
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
|
||||
):
|
||||
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
|
||||
continue
|
||||
|
||||
if tenant_id in self.tenant_ids:
|
||||
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
|
||||
continue
|
||||
@@ -197,9 +190,6 @@ class SlackbotHandler:
|
||||
continue
|
||||
|
||||
logger.debug(f"Acquired lock for tenant {tenant_id}")
|
||||
self.tenant_ids.add(tenant_id)
|
||||
|
||||
for tenant_id in self.tenant_ids:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
|
||||
tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
@@ -246,14 +236,14 @@ class SlackbotHandler:
|
||||
|
||||
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
|
||||
|
||||
if self.socket_clients.get(tenant_id):
|
||||
if tenant_id in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
|
||||
self.start_socket_client(tenant_id, slack_bot_tokens)
|
||||
|
||||
except KvKeyNotFoundError:
|
||||
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
|
||||
if self.socket_clients.get(tenant_id):
|
||||
if tenant_id in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
del self.socket_clients[tenant_id]
|
||||
del self.slack_bot_tokens[tenant_id]
|
||||
@@ -287,14 +277,14 @@ class SlackbotHandler:
|
||||
logger.info(f"Connecting socket client for tenant {tenant_id}")
|
||||
socket_client.connect()
|
||||
self.socket_clients[tenant_id] = socket_client
|
||||
self.tenant_ids.add(tenant_id)
|
||||
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
|
||||
|
||||
def stop_socket_clients(self) -> None:
|
||||
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
|
||||
for tenant_id, client in self.socket_clients.items():
|
||||
if client:
|
||||
asyncio.run(client.close())
|
||||
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
|
||||
asyncio.run(client.close())
|
||||
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
|
||||
|
||||
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
|
||||
if not self.running:
|
||||
@@ -308,16 +298,6 @@ class SlackbotHandler:
|
||||
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
|
||||
self.stop_socket_clients()
|
||||
|
||||
# Release locks for all tenants
|
||||
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
|
||||
for tenant_id in self.tenant_ids:
|
||||
try:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(DanswerRedisLocks.SLACK_BOT_LOCK)
|
||||
logger.info(f"Released lock for tenant {tenant_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
|
||||
|
||||
# Wait for background threads to finish (with timeout)
|
||||
logger.info("Waiting for background threads to finish...")
|
||||
self.acquire_thread.join(timeout=5)
|
||||
|
||||
@@ -24,13 +24,6 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
|
||||
return tool
|
||||
|
||||
|
||||
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
|
||||
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
|
||||
if not tool:
|
||||
raise ValueError("Tool by specified name does not exist")
|
||||
return tool
|
||||
|
||||
|
||||
def create_tool(
|
||||
name: str,
|
||||
description: str | None,
|
||||
@@ -44,7 +37,7 @@ def create_tool(
|
||||
description=description,
|
||||
in_code_tool_id=None,
|
||||
openapi_schema=openapi_schema,
|
||||
custom_headers=[header.model_dump() for header in custom_headers]
|
||||
custom_headers=[header.dict() for header in custom_headers]
|
||||
if custom_headers
|
||||
else [],
|
||||
user_id=user_id,
|
||||
|
||||
@@ -74,9 +74,6 @@ from danswer.server.manage.search_settings import router as search_settings_rout
|
||||
from danswer.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from danswer.server.manage.users import router as user_router
|
||||
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
|
||||
from danswer.server.openai_assistants_api.full_openai_assistants_api import (
|
||||
get_full_openai_assistants_api_router,
|
||||
)
|
||||
from danswer.server.query_and_chat.chat_backend import router as chat_router
|
||||
from danswer.server.query_and_chat.query_backend import (
|
||||
admin_router as admin_query_router,
|
||||
@@ -273,9 +270,6 @@ def get_application() -> FastAPI:
|
||||
application, token_rate_limit_settings_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, indexing_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, get_full_openai_assistants_api_router()
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
|
||||
@@ -63,7 +63,6 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,44 +0,0 @@
|
||||
[
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/overview",
|
||||
"title": "Use Cases Overview",
|
||||
"content": "How to leverage Danswer in your organization\n\nDanswer Overview\nDanswer is the AI Assistant connected to your organization's docs, apps, and people. Danswer makes Generative AI more versatile for work by enabling new types of questions like \"What is the most common feature request we've heard from customers this month\". Whereas other AI systems have no context of your team and are generally unhelpful with work related questions, Danswer makes it possible to ask these questions in natural language and get back answers in seconds.\n\nDanswer can connect to +30 different tools and the use cases are not limited to the ones in the following pages. The highlighted use cases are for inspiration and come from feedback gathered from our users and customers.\n\n\nCommon Getting Started Questions:\n\nWhy are these docs connected in my Danswer deployment?\nAnswer: This is just an example of how connectors work in Danswer. You can connect up your own team's knowledge and you will be able to ask questions unique to your organization. Danswer will keep all of the knowledge up to date and in sync with your connected applications.\n\nIs my data being sent anywhere when I connect it up to Danswer?\nAnswer: No! Danswer is built with data security as our highest priority. We open sourced it so our users can know exactly what is going on with their data. By default all of the document processing happens within Danswer. The only time it is sent outward is for the GenAI call to generate answers.\n\nWhere is the feature for auto sync-ing document level access permissions from all connected sources?\nAnswer: This falls under the Enterprise Edition set of Danswer features built on top of the MIT/community edition. If you are on Danswer Cloud, you have access to them by default. If you're running it yourself, reach out to the Danswer team to receive access.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/enterprise_search",
|
||||
"title": "Enterprise Search",
|
||||
"content": "Value of Enterprise Search with Danswer\n\nWhat is Enterprise Search and why is it Important?\nAn Enterprise Search system gives team members a single place to access all of the disparate knowledge of an organization. Critical information is saved across a host of channels like call transcripts with prospects, engineering design docs, IT runbooks, customer support email exchanges, project management tickets, and more. As fast moving teams scale up, information gets spread out and more disorganized.\n\nSince it quickly becomes infeasible to check across every source, decisions get made on incomplete information, employee satisfaction decreases, and the most valuable members of your team are tied up with constant distractions as junior teammates are unable to unblock themselves. Danswer solves this problem by letting anyone on the team access all of the knowledge across your organization in a permissioned and secure way. Users can ask questions in natural language and get back answers and documents across all of the connected sources instantly.\n\nWhat's the real cost?\nA typical knowledge worker spends over 2 hours a week on search, but more than that, the cost of incomplete or incorrect information can be extremely high. Customer support/success that isn't able to find the reference to similar cases could cause hours or even days of delay leading to lower customer satisfaction or in the worst case - churn. An account exec not realizing that a prospect had previously mentioned a specific need could lead to lost deals. An engineer not realizing a similar feature had previously been built could result in weeks of wasted development time and tech debt with duplicate implementation. With a lack of knowledge, your whole organization is navigating in the dark - inefficient and mistake prone.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/enterprise_search",
|
||||
"title": "Enterprise Search",
|
||||
"content": "More than Search\nWhen analyzing the entire corpus of knowledge within your company is as easy as asking a question in a search bar, your entire team can stay informed and up to date. Danswer also makes it trivial to identify where knowledge is well documented and where it is lacking. Team members who are centers of knowledge can begin to effectively document their expertise since it is no longer being thrown into a black hole. All of this allows the organization to achieve higher efficiency and drive business outcomes.\n\nWith Generative AI, the entire user experience has evolved as well. For example, instead of just finding similar cases for your customer support team to reference, Danswer breaks down the issue and explains it so that even the most junior members can understand it. This in turn lets them give the most holistic and technically accurate response possible to your customers. On the other end, even the super stars of your sales team will not be able to review 10 hours of transcripts before hopping on that critical call, but Danswer can easily parse through it in mere seconds and give crucial context to help your team close.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/ai_platform",
|
||||
"title": "AI Platform",
|
||||
"content": "Build AI Agents powered by the knowledge and workflows specific to your organization.\n\nBeyond Answers\nAgents enabled by generative AI and reasoning capable models are helping teams to automate their work. Danswer is helping teams make it happen. Danswer provides out of the box user chat sessions, attaching custom tools, handling LLM reasoning, code execution, data analysis, referencing internal knowledge, and much more.\n\nDanswer as a platform is not a no-code agent builder. We are made by developers for developers and this gives your team the full flexibility and power to create agents not constrained by blocks and simple logic paths.\n\nFlexibility and Extensibility\nDanswer is open source and completely whitebox. This not only gives transparency to what happens within the system but also means that your team can directly modify the source code to suit your unique needs.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/customer_support",
|
||||
"title": "Customer Support",
|
||||
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nDanswer takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/sales",
|
||||
"title": "Sales",
|
||||
"content": "Keep your team up to date on every conversation and update so they can close.\n\nRecall Every Detail\nBeing able to instantly revisit every detail of any call without reading transcripts is helping Sales teams provide more tailored pitches, build stronger relationships, and close more deals. Instead of searching and reading through hours of transcripts in preparation for a call, your team can now ask Danswer \"What specific features was ACME interested in seeing for the demo\". Since your team doesn't have time to read every transcript prior to a call, Danswer provides a more thorough summary because it can instantly parse hundreds of pages and distill out the relevant information. Even for fast lookups it becomes much more convenient - for example to brush up on connection building topics by asking \"What rapport building topic did we chat about in the last call with ACME\".\n\nKnow Every Product Update\nIt is impossible for Sales teams to keep up with every product update. Because of this, when a prospect has a question that the Sales team does not know, they have no choice but to rely on the Product and Engineering orgs to get an authoritative answer. Not only is this distracting to the other teams, it also slows down the time to respond to the prospect (and as we know, time is the biggest killer of deals). With Danswer, it is even possible to get answers live on call because of how fast accessing information becomes. A question like \"Have we shipped the Microsoft AD integration yet?\" can now be answered in seconds meaning that prospects can get answers while on the call instead of asynchronously and sales cycles are reduced as a result.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/operations",
|
||||
"title": "Operations",
|
||||
"content": "Double the productivity of your Ops teams like IT, HR, etc.\n\nAutomatically Resolve Tickets\nModern teams are leveraging AI to auto-resolve up to 50% of tickets. Whether it is an employee asking about benefits details or how to set up the VPN for remote work, Danswer can help your team help themselves. This frees up your team to do the real impactful work of landing star candidates or improving your internal processes.\n\nAI Aided Onboarding\nOne of the periods where your team needs the most help is when they're just ramping up. Instead of feeling lost in dozens of new tools, Danswer gives them a single place where they can ask about anything in natural language. Whether it's how to set up their work environment or what their onboarding goals are, Danswer can walk them through every step with the help of Generative AI. This lets your team feel more empowered and gives time back to the more seasoned members of your team to focus on moving the needle.",
|
||||
"chunk_ind": 0
|
||||
}
|
||||
]
|
||||
@@ -3,7 +3,6 @@ import json
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from cohere import Client
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import default_public_access
|
||||
@@ -33,7 +32,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.server.documents.models import ConnectorBase
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -92,9 +91,7 @@ def _create_indexable_chunks(
|
||||
return list(ids_to_documents.values()), chunks
|
||||
|
||||
|
||||
def seed_initial_documents(
|
||||
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
|
||||
) -> None:
|
||||
def seed_initial_documents(db_session: Session, tenant_id: str | None) -> None:
|
||||
"""
|
||||
Seed initial documents so users don't have an empty index to start
|
||||
|
||||
@@ -135,9 +132,7 @@ def seed_initial_documents(
|
||||
return
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
if search_settings.model_name != DEFAULT_DOCUMENT_ENCODER_MODEL and not (
|
||||
search_settings.model_name == "embed-english-v3.0" and cohere_enabled
|
||||
):
|
||||
if search_settings.model_name != DEFAULT_DOCUMENT_ENCODER_MODEL:
|
||||
logger.info("Embedding model has been updated, skipping")
|
||||
return
|
||||
|
||||
@@ -178,31 +173,10 @@ def seed_initial_documents(
|
||||
)
|
||||
cc_pair_id = cast(int, result.data)
|
||||
|
||||
if cohere_enabled:
|
||||
initial_docs_path = os.path.join(
|
||||
os.getcwd(), "danswer", "seeding", "initial_docs_cohere.json"
|
||||
)
|
||||
|
||||
cohere_client = Client(COHERE_DEFAULT_API_KEY)
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
for doc in processed_docs:
|
||||
title_embedding = cohere_client.embed(
|
||||
texts=[doc["title"]], model="embed-english-v3.0"
|
||||
).embeddings[0]
|
||||
content_embedding = cohere_client.embed(
|
||||
texts=[doc["content"]], model="embed-english-v3.0"
|
||||
).embeddings[0]
|
||||
doc["title_embedding"] = title_embedding
|
||||
doc["content_embedding"] = content_embedding
|
||||
|
||||
else:
|
||||
initial_docs_path = os.path.join(
|
||||
os.getcwd(),
|
||||
"danswer",
|
||||
"seeding",
|
||||
"initial_docs.json",
|
||||
)
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
initial_docs_path = os.path.join(
|
||||
os.getcwd(), "danswer", "seeding", "initial_docs.json"
|
||||
)
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
|
||||
docs, chunks = _create_indexable_chunks(processed_docs, tenant_id)
|
||||
|
||||
|
||||
@@ -1,273 +0,0 @@
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.persona import get_personas
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.persona import upsert_prompt
|
||||
from danswer.db.tools import get_tool_by_name
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
router = APIRouter(prefix="/assistants")
|
||||
|
||||
|
||||
# Base models
|
||||
class AssistantObject(BaseModel):
|
||||
id: int
|
||||
object: str = "assistant"
|
||||
created_at: int
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
model: str
|
||||
instructions: Optional[str] = None
|
||||
tools: list[dict[str, Any]]
|
||||
file_ids: list[str]
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class CreateAssistantRequest(BaseModel):
|
||||
model: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
tools: Optional[list[dict[str, Any]]] = None
|
||||
file_ids: Optional[list[str]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class ModifyAssistantRequest(BaseModel):
|
||||
model: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
tools: Optional[list[dict[str, Any]]] = None
|
||||
file_ids: Optional[list[str]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class DeleteAssistantResponse(BaseModel):
|
||||
id: int
|
||||
object: str = "assistant.deleted"
|
||||
deleted: bool
|
||||
|
||||
|
||||
class ListAssistantsResponse(BaseModel):
|
||||
object: str = "list"
|
||||
data: list[AssistantObject]
|
||||
first_id: Optional[int] = None
|
||||
last_id: Optional[int] = None
|
||||
has_more: bool
|
||||
|
||||
|
||||
def persona_to_assistant(persona: Persona) -> AssistantObject:
|
||||
return AssistantObject(
|
||||
id=persona.id,
|
||||
created_at=0,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
model=persona.llm_model_version_override or "gpt-3.5-turbo",
|
||||
instructions=persona.prompts[0].system_prompt if persona.prompts else None,
|
||||
tools=[
|
||||
{
|
||||
"type": tool.display_name,
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"schema": tool.openapi_schema,
|
||||
},
|
||||
}
|
||||
for tool in persona.tools
|
||||
],
|
||||
file_ids=[], # Assuming no file support for now
|
||||
metadata={}, # Assuming no metadata for now
|
||||
)
|
||||
|
||||
|
||||
# API endpoints
|
||||
@router.post("")
|
||||
def create_assistant(
|
||||
request: CreateAssistantRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
prompt = None
|
||||
if request.instructions:
|
||||
prompt = upsert_prompt(
|
||||
user=user,
|
||||
name=f"Prompt for {request.name or 'New Assistant'}",
|
||||
description="Auto-generated prompt",
|
||||
system_prompt=request.instructions,
|
||||
task_prompt="",
|
||||
include_citations=True,
|
||||
datetime_aware=True,
|
||||
personas=[],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
tool_ids = []
|
||||
for tool in request.tools or []:
|
||||
tool_type = tool.get("type")
|
||||
if not tool_type:
|
||||
continue
|
||||
|
||||
try:
|
||||
tool_db = get_tool_by_name(tool_type, db_session)
|
||||
tool_ids.append(tool_db.id)
|
||||
except ValueError:
|
||||
# Skip tools that don't exist in the database
|
||||
logger.error(f"Tool {tool_type} not found in database")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Tool {tool_type} not found in database"
|
||||
)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=request.name or f"Assistant-{uuid4()}",
|
||||
description=request.description or "",
|
||||
num_chunks=25,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=request.model,
|
||||
starter_messages=None,
|
||||
is_public=False,
|
||||
db_session=db_session,
|
||||
prompt_ids=[prompt.id] if prompt else [0],
|
||||
document_set_ids=[],
|
||||
tool_ids=tool_ids,
|
||||
icon_color=None,
|
||||
icon_shape=None,
|
||||
is_visible=True,
|
||||
)
|
||||
|
||||
if prompt:
|
||||
prompt.personas = [persona]
|
||||
db_session.commit()
|
||||
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
""
|
||||
|
||||
|
||||
@router.get("/{assistant_id}")
|
||||
def retrieve_assistant(
|
||||
assistant_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
try:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
except ValueError:
|
||||
persona = None
|
||||
|
||||
if not persona:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
@router.post("/{assistant_id}")
|
||||
def modify_assistant(
|
||||
assistant_id: int,
|
||||
request: ModifyAssistantRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=True,
|
||||
)
|
||||
if not persona:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(persona, key, value)
|
||||
|
||||
if "instructions" in update_data and persona.prompts:
|
||||
persona.prompts[0].system_prompt = update_data["instructions"]
|
||||
|
||||
db_session.commit()
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
@router.delete("/{assistant_id}")
|
||||
def delete_assistant(
|
||||
assistant_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DeleteAssistantResponse:
|
||||
try:
|
||||
mark_persona_as_deleted(
|
||||
persona_id=int(assistant_id),
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
return DeleteAssistantResponse(id=assistant_id, deleted=True)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_assistants(
|
||||
limit: int = Query(20, le=100),
|
||||
order: str = Query("desc", regex="^(asc|desc)$"),
|
||||
after: Optional[int] = None,
|
||||
before: Optional[int] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ListAssistantsResponse:
|
||||
personas = list(
|
||||
get_personas(
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
get_editable=False,
|
||||
joinedload_all=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply filtering based on after and before
|
||||
if after:
|
||||
personas = [p for p in personas if p.id > int(after)]
|
||||
if before:
|
||||
personas = [p for p in personas if p.id < int(before)]
|
||||
|
||||
# Apply ordering
|
||||
personas.sort(key=lambda p: p.id, reverse=(order == "desc"))
|
||||
|
||||
# Apply limit
|
||||
personas = personas[:limit]
|
||||
|
||||
assistants = [persona_to_assistant(p) for p in personas]
|
||||
|
||||
return ListAssistantsResponse(
|
||||
data=assistants,
|
||||
first_id=assistants[0].id if assistants else None,
|
||||
last_id=assistants[-1].id if assistants else None,
|
||||
has_more=len(personas) == limit,
|
||||
)
|
||||
@@ -1,19 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from danswer.server.openai_assistants_api.asssistants_api import (
|
||||
router as assistants_router,
|
||||
)
|
||||
from danswer.server.openai_assistants_api.messages_api import router as messages_router
|
||||
from danswer.server.openai_assistants_api.runs_api import router as runs_router
|
||||
from danswer.server.openai_assistants_api.threads_api import router as threads_router
|
||||
|
||||
|
||||
def get_full_openai_assistants_api_router() -> APIRouter:
|
||||
router = APIRouter(prefix="/openai-assistants")
|
||||
|
||||
router.include_router(assistants_router)
|
||||
router.include_router(runs_router)
|
||||
router.include_router(threads_router)
|
||||
router.include_router(messages_router)
|
||||
|
||||
return router
|
||||
@@ -1,235 +0,0 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
|
||||
router = APIRouter(prefix="")
|
||||
|
||||
|
||||
Role = Literal["user", "assistant"]
|
||||
|
||||
|
||||
class MessageContent(BaseModel):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}")
|
||||
object: Literal["thread.message"] = "thread.message"
|
||||
created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
thread_id: str
|
||||
role: Role
|
||||
content: list[MessageContent]
|
||||
file_ids: list[str] = []
|
||||
assistant_id: Optional[str] = None
|
||||
run_id: Optional[str] = None
|
||||
metadata: Optional[dict[str, Any]] = None # Change this line to use dict[str, Any]
|
||||
|
||||
|
||||
class CreateMessageRequest(BaseModel):
|
||||
role: Role
|
||||
content: str
|
||||
file_ids: list[str] = []
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
class ListMessagesResponse(BaseModel):
|
||||
object: Literal["list"] = "list"
|
||||
data: list[Message]
|
||||
first_id: str
|
||||
last_id: str
|
||||
has_more: bool
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/messages")
|
||||
def create_message(
|
||||
thread_id: str,
|
||||
message: CreateMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=uuid.UUID(thread_id),
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
latest_message = (
|
||||
chat_messages[-1]
|
||||
if chat_messages
|
||||
else get_or_create_root_message(chat_session.id, db_session)
|
||||
)
|
||||
|
||||
new_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=latest_message,
|
||||
message=message.content,
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
token_count=check_number_of_tokens(message.content),
|
||||
message_type=(
|
||||
MessageType.USER if message.role == "user" else MessageType.ASSISTANT
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return Message(
|
||||
id=str(new_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user",
|
||||
content=[MessageContent(type="text", text=message.content)],
|
||||
file_ids=message.file_ids,
|
||||
metadata=message.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/messages")
|
||||
def list_messages(
|
||||
thread_id: str,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ListMessagesResponse:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=uuid.UUID(thread_id),
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Apply filtering based on after and before
|
||||
if after:
|
||||
messages = [m for m in messages if str(m.id) >= after]
|
||||
if before:
|
||||
messages = [m for m in messages if str(m.id) <= before]
|
||||
|
||||
# Apply ordering
|
||||
messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))
|
||||
|
||||
# Apply limit
|
||||
messages = messages[:limit]
|
||||
|
||||
data = [
|
||||
Message(
|
||||
id=str(m.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if m.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=m.message)],
|
||||
created_at=int(m.time_sent.timestamp()),
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
return ListMessagesResponse(
|
||||
data=data,
|
||||
first_id=str(data[0].id) if data else "",
|
||||
last_id=str(data[-1].id) if data else "",
|
||||
has_more=len(messages) == limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/messages/{message_id}")
|
||||
def retrieve_message(
|
||||
thread_id: str,
|
||||
message_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=message_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
return Message(
|
||||
id=str(chat_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if chat_message.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=chat_message.message)],
|
||||
created_at=int(chat_message.time_sent.timestamp()),
|
||||
)
|
||||
|
||||
|
||||
class ModifyMessageRequest(BaseModel):
|
||||
metadata: dict
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/messages/{message_id}")
|
||||
def modify_message(
|
||||
thread_id: str,
|
||||
message_id: int,
|
||||
request: ModifyMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=message_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
# Update metadata
|
||||
# TODO: Uncomment this once we have metadata in the chat message
|
||||
# chat_message.metadata = request.metadata
|
||||
# db_session.commit()
|
||||
|
||||
return Message(
|
||||
id=str(chat_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if chat_message.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=chat_message.message)],
|
||||
created_at=int(chat_message.time_sent.timestamp()),
|
||||
metadata=request.metadata,
|
||||
)
|
||||
@@ -1,344 +0,0 @@
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import User
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class RunRequest(BaseModel):
    """Request body for POST /threads/{thread_id}/runs (OpenAI "create run")."""

    # Danswer persona id acting as the "assistant".
    assistant_id: int
    # Optional model override; only echoed back in the response below.
    model: Optional[str] = None
    # The user prompt for this run (sent as the chat message content).
    instructions: Optional[str] = None
    # Accepted for API compatibility; not used by the visible handlers.
    additional_instructions: Optional[str] = None
    # Tool configs; a SearchTool entry may carry "retrieval_details".
    tools: Optional[list[dict]] = None
    # Arbitrary client metadata, echoed back in the response.
    metadata: Optional[dict] = None
|
||||
|
||||
|
||||
# Lifecycle states mirroring OpenAI's run-status enum. The handlers in this
# module only ever report "queued", "in_progress", "completed" and
# "cancelled"; the rest exist for API-shape compatibility.
RunStatus = Literal[
    "queued",
    "in_progress",
    "requires_action",
    "cancelling",
    "cancelled",
    "failed",
    "completed",
    "expired",
]
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
    """OpenAI-compatible representation of a run (a Danswer assistant message)."""

    # Stringified chat-message id.
    id: str
    object: Literal["thread.run"]
    # Unix timestamp derived from the message's time_sent.
    created_at: int
    assistant_id: int
    # Chat session id.
    thread_id: UUID
    status: RunStatus
    started_at: Optional[int] = None
    expires_at: Optional[int] = None
    cancelled_at: Optional[int] = None
    failed_at: Optional[int] = None
    completed_at: Optional[int] = None
    last_error: Optional[dict] = None
    model: str
    instructions: str
    tools: list[dict]
    file_ids: list[str]
    metadata: Optional[dict] = None
|
||||
|
||||
|
||||
def process_run_in_background(
    message_id: int,
    parent_message_id: int,
    chat_session_id: UUID,
    assistant_id: int,
    instructions: str,
    tools: list[dict],
    user: User | None,
    db_session: Session,
) -> None:
    """Drive one "run" to completion as a FastAPI background task.

    Streams the LLM answer for the message created by ``create_run`` and
    incrementally persists the generated text onto that placeholder assistant
    message. On failure (or if the run was cancelled by setting ``.error``),
    it stops and records the error on the message.
    """
    # Get the latest message in the chat session
    chat_session = get_chat_session_by_id(
        chat_session_id=chat_session_id,
        user_id=user.id if user else None,
        db_session=db_session,
    )

    # If the caller configured the search tool, honor its retrieval details;
    # otherwise fall back to defaults.
    search_tool_retrieval_details = RetrievalDetails()
    for tool in tools:
        if tool["type"] == SearchTool.__name__ and (
            retrieval_details := tool.get("retrieval_details")
        ):
            search_tool_retrieval_details = RetrievalDetails.model_validate(
                retrieval_details
            )
            break

    # NOTE(review): assumes the session's persona has at least one prompt —
    # prompts[0] raises IndexError otherwise; confirm upstream invariant.
    new_msg_req = CreateChatMessageRequest(
        chat_session_id=chat_session_id,
        parent_message_id=int(parent_message_id) if parent_message_id else None,
        message=instructions,
        file_descriptors=[],
        prompt_id=chat_session.persona.prompts[0].id,
        search_doc_ids=None,
        retrieval_options=search_tool_retrieval_details,  # Adjust as needed
        query_override=None,
        regenerate=None,
        llm_override=None,
        prompt_override=None,
        alternate_assistant_id=assistant_id,
        use_existing_user_message=True,
        existing_assistant_message_id=message_id,
    )

    run_message = get_chat_message(message_id, user.id if user else None, db_session)
    try:
        for packet in stream_chat_message_objects(
            new_msg_req=new_msg_req,
            user=user,
            db_session=db_session,
        ):
            if isinstance(packet, ChatMessageDetail):
                # Update the run status and message content. Re-fetched each
                # time so a concurrent cancel (which sets .error) is observed.
                run_message = get_chat_message(
                    message_id, user.id if user else None, db_session
                )
                if run_message:
                    # this handles cancelling
                    if run_message.error:
                        return

                    run_message.message = packet.message
                    run_message.message_type = MessageType.ASSISTANT
                    db_session.commit()
    except Exception as e:
        # Record the failure on the run message so retrieve_run reports it.
        logger.exception("Error processing run in background")
        run_message.error = str(e)
        db_session.commit()
        return

    # Zero tokens after a clean stream means the model produced nothing.
    db_session.refresh(run_message)
    if run_message.token_count == 0:
        run_message.error = "No tokens generated"
        db_session.commit()
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/runs")
def create_run(
    thread_id: UUID,
    run_request: RunRequest,
    background_tasks: BackgroundTasks,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    """Create a run: a placeholder assistant message answered in the background.

    Immediately persists an empty assistant message (the "run"), schedules
    ``process_run_in_background`` to fill it in, and returns a "queued"
    RunResponse whose id is the new message's id.

    Raises:
        HTTPException: 404 if the thread does not exist or is not visible
            to the requesting user.
    """
    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=thread_id,
            user_id=user.id if user else None,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    # The new run is chained off the newest message (or the root message for
    # an empty session).
    chat_messages = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=user.id if user else None,
        db_session=db_session,
    )
    latest_message = (
        chat_messages[-1]
        if chat_messages
        else get_or_create_root_message(chat_session.id, db_session)
    )

    # Create a new "run" (chat message) in the session.
    # NOTE(review): assumes the persona has at least one prompt; prompts[0]
    # raises IndexError otherwise — confirm upstream invariant.
    new_message = create_new_chat_message(
        chat_session_id=chat_session.id,
        parent_message=latest_message,
        message="",
        prompt_id=chat_session.persona.prompts[0].id,
        token_count=0,
        message_type=MessageType.ASSISTANT,
        db_session=db_session,
        commit=False,
    )
    # Flush (not commit) first so new_message.id is assigned before it is
    # linked from the parent; a single commit then persists both changes.
    db_session.flush()
    latest_message.latest_child_message = new_message.id
    db_session.commit()

    # Schedule the background task
    background_tasks.add_task(
        process_run_in_background,
        new_message.id,
        latest_message.id,
        chat_session.id,
        run_request.assistant_id,
        run_request.instructions or "",
        run_request.tools or [],
        user,
        db_session,
    )

    return RunResponse(
        id=str(new_message.id),
        object="thread.run",
        created_at=int(new_message.time_sent.timestamp()),
        assistant_id=run_request.assistant_id,
        thread_id=chat_session.id,
        status="queued",
        model=run_request.model or "default_model",
        instructions=run_request.instructions or "",
        tools=run_request.tools or [],
        file_ids=[],
        metadata=run_request.metadata,
    )
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}")
def retrieve_run(
    thread_id: UUID,
    run_id: str,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    """Return the OpenAI-style "run" view of a single chat message.

    A run in Danswer is the assistant chat message created by ``create_run``;
    its status is derived from the message's content, token count, and error
    fields.

    Raises:
        HTTPException: 404 if the run id is malformed, does not exist, is not
            visible to the requesting user, or belongs to a different thread.
    """
    # Run ids are stringified chat-message ids. Previously a non-numeric id
    # hit int(run_id) unguarded and surfaced as an HTTP 500.
    try:
        message_id = int(run_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Run not found")

    # get_chat_message raises ValueError for missing/inaccessible messages
    # (see the other handlers in this module); translate that to a 404 too.
    try:
        chat_message = get_chat_message(
            chat_message_id=message_id,
            user_id=user.id if user else None,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Run not found")
    if not chat_message:
        raise HTTPException(status_code=404, detail="Run not found")

    chat_session = chat_message.chat_session
    # Previously any valid run id was returned regardless of the thread in
    # the URL; require the run to actually belong to the requested thread.
    if chat_session.id != thread_id:
        raise HTTPException(status_code=404, detail="Run not found")

    # Map Danswer message state to the OpenAI status enum: content means the
    # stream started, a token count means it finished, an error wins overall.
    run_status: RunStatus = "queued"
    if chat_message.message:
        run_status = "in_progress"
    if chat_message.token_count != 0:
        run_status = "completed"
    if chat_message.error:
        run_status = "cancelled"

    return RunResponse(
        id=run_id,
        object="thread.run",
        created_at=int(chat_message.time_sent.timestamp()),
        assistant_id=chat_session.persona_id or 0,
        thread_id=chat_session.id,
        status=run_status,
        started_at=int(chat_message.time_sent.timestamp()),
        completed_at=(
            int(chat_message.time_sent.timestamp()) if chat_message.message else None
        ),
        model=chat_session.current_alternate_model or "default_model",
        instructions="",  # Danswer doesn't store per-message instructions
        tools=[],  # Danswer doesn't have a direct equivalent for tools
        file_ids=(
            [file["id"] for file in chat_message.files] if chat_message.files else []
        ),
        metadata=None,  # Danswer doesn't store metadata for individual messages
    )
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
def cancel_run(
    thread_id: UUID,
    run_id: str,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    """Cancel a run by marking its chat message as errored.

    Danswer has no native run cancellation; ``process_run_in_background``
    checks the message's ``error`` field between stream packets and stops
    when it is set.

    Raises:
        HTTPException: 404 if the run id is malformed or no such message exists.
    """
    # Run ids are stringified integer message ids. Previously the raw string
    # was compared against the integer ChatMessage.id column, which is at
    # best driver-dependent and at worst a type error — compare as int.
    try:
        message_id = int(run_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Run not found")

    chat_message = (
        db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
    )
    if not chat_message:
        raise HTTPException(status_code=404, detail="Run not found")

    chat_message.error = "Cancelled"
    db_session.commit()

    # Re-use retrieve_run so the response shape/status mapping stays in sync.
    return retrieve_run(thread_id, run_id, user, db_session)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs")
def list_runs(
    thread_id: UUID,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[RunResponse]:
    """List the "runs" (chat messages) of a thread, OpenAI-style.

    Supports cursor pagination via ``after``/``before`` (exclusive message-id
    cursors), ``order`` by message time, and a result ``limit``.

    Raises:
        HTTPException: 400 for a non-numeric cursor.
    """

    def _cursor_to_id(cursor: str, name: str) -> int:
        # Cursors are stringified integer message ids. They were previously
        # compared as strings, so e.g. "10" < "9" silently broke pagination.
        try:
            return int(cursor)
        except ValueError:
            raise HTTPException(status_code=400, detail=f"Invalid '{name}' cursor")

    # In Danswer, each message in a chat session is treated as a "run".
    chat_messages = get_chat_messages_by_session(
        chat_session_id=thread_id,
        user_id=user.id if user else None,
        db_session=db_session,
    )

    # Apply pagination (numeric, exclusive bounds)
    if after is not None:
        after_id = _cursor_to_id(after, "after")
        chat_messages = [msg for msg in chat_messages if msg.id > after_id]
    if before is not None:
        before_id = _cursor_to_id(before, "before")
        chat_messages = [msg for msg in chat_messages if msg.id < before_id]

    # Apply ordering by send time
    chat_messages = sorted(
        chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc")
    )

    # Apply limit
    chat_messages = chat_messages[:limit]

    return [
        retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages
    ]
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}/steps")
def list_run_steps(
    run_id: str,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[dict]:  # You may want to create a specific model for run steps
    """List the steps of a run — always empty.

    A Danswer run is a single chat message with no step breakdown, so an
    empty list is returned purely for OpenAI API compatibility; all query
    parameters are accepted but ignored.
    """
    return []
|
||||
|
||||
|
||||
# Additional helper functions can be added here if needed
|
||||
@@ -1,156 +0,0 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import delete_chat_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_chat_sessions_by_user
|
||||
from danswer.db.chat import update_chat_session
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.server.query_and_chat.models import ChatSessionDetails
|
||||
from danswer.server.query_and_chat.models import ChatSessionsResponse
|
||||
|
||||
router = APIRouter(prefix="/threads")
|
||||
|
||||
|
||||
# Models
|
||||
# Models
class Thread(BaseModel):
    """OpenAI-compatible view of a Danswer chat session."""

    # The chat session's id.
    id: UUID
    object: str = "thread"
    # Unix timestamp of the session's creation time.
    created_at: int
    # Echoed from the request; sessions don't persist metadata (see handlers).
    metadata: Optional[dict[str, str]] = None
|
||||
|
||||
|
||||
class CreateThreadRequest(BaseModel):
    """Request body for POST /threads."""

    # Accepted for API compatibility; create_thread does not use them.
    messages: Optional[list[dict]] = None
    # Echoed back in the response; not persisted.
    metadata: Optional[dict[str, str]] = None
|
||||
|
||||
|
||||
class ModifyThreadRequest(BaseModel):
    """Request body for POST /threads/{thread_id}."""

    # Echoed back in the response; not persisted.
    metadata: Optional[dict[str, str]] = None
|
||||
|
||||
|
||||
# API Endpoints
|
||||
# API Endpoints
@router.post("")
def create_thread(
    request: CreateThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    """Create a thread by opening a new Danswer chat session.

    ``request.messages`` is accepted for API compatibility but ignored;
    metadata is echoed back rather than stored.
    """
    owner_id = None if user is None else user.id
    session = create_chat_session(
        db_session=db_session,
        description="",  # Leave the naming till later to prevent delay
        user_id=owner_id,
        persona_id=0,
    )

    return Thread(
        id=session.id,
        created_at=int(session.time_created.timestamp()),
        metadata=request.metadata,
    )
|
||||
|
||||
|
||||
@router.get("/{thread_id}")
def retrieve_thread(
    thread_id: UUID,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    """Fetch a thread (chat session); 404 if missing or not visible."""
    owner_id = None if user is None else user.id
    try:
        session = get_chat_session_by_id(
            chat_session_id=thread_id,
            user_id=owner_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    return Thread(
        id=session.id,
        created_at=int(session.time_created.timestamp()),
        metadata=None,  # Assuming we don't store metadata in our current implementation
    )
|
||||
|
||||
|
||||
@router.post("/{thread_id}")
def modify_thread(
    thread_id: UUID,
    request: ModifyThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    """Modify a thread's metadata (echoed only — nothing is persisted).

    The underlying session update passes ``None`` for both description and
    sharing status, so this effectively just checks the thread exists.
    """
    owner_id = None if user is None else user.id
    try:
        session = update_chat_session(
            db_session=db_session,
            user_id=owner_id,
            chat_session_id=thread_id,
            description=None,  # Not updating description
            sharing_status=None,  # Not updating sharing status
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    return Thread(
        id=session.id,
        created_at=int(session.time_created.timestamp()),
        metadata=request.metadata,
    )
|
||||
|
||||
|
||||
@router.delete("/{thread_id}")
def delete_thread(
    thread_id: UUID,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict:
    """Delete a thread (chat session); 404 if missing or not visible."""
    owner_id = None if user is None else user.id
    try:
        delete_chat_session(
            user_id=owner_id,
            chat_session_id=thread_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    # OpenAI-style deletion receipt.
    return {"id": str(thread_id), "object": "thread.deleted", "deleted": True}
|
||||
|
||||
|
||||
@router.get("")
def list_threads(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
    """List the requesting user's non-deleted chat sessions as threads."""
    owner_id = None if user is None else user.id
    sessions = get_chat_sessions_by_user(
        user_id=owner_id,
        deleted=False,
        db_session=db_session,
    )

    details = [
        ChatSessionDetails(
            id=session.id,
            name=session.description,
            persona_id=session.persona_id,
            time_created=session.time_created.isoformat(),
            shared_status=session.shared_status,
            folder_id=session.folder_id,
            current_alternate_model=session.current_alternate_model,
        )
        for session in sessions
    ]
    return ChatSessionsResponse(sessions=details)
|
||||
@@ -347,6 +347,7 @@ def handle_new_chat_message(
|
||||
for packet in stream_chat_message(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
use_existing_user_message=chat_message_req.use_existing_user_message,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
|
||||
@@ -108,9 +108,6 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# used for seeded chats to kick off the generation of an AI answer
|
||||
use_existing_user_message: bool = False
|
||||
|
||||
# used for "OpenAI Assistants API"
|
||||
existing_assistant_message_id: int | None = None
|
||||
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@@ -59,9 +59,7 @@ from shared_configs.model_server_models import SupportedEmbeddingModel
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def setup_danswer(
|
||||
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
|
||||
) -> None:
|
||||
def setup_danswer(db_session: Session, tenant_id: str | None) -> None:
|
||||
"""
|
||||
Setup Danswer for a particular tenant. In the Single Tenant case, it will set it up for the default schema
|
||||
on server startup. In the MT case, it will be called when the tenant is created.
|
||||
@@ -150,7 +148,7 @@ def setup_danswer(
|
||||
# update multipass indexing setting based on GPU availability
|
||||
update_default_multipass_indexing(db_session)
|
||||
|
||||
seed_initial_documents(db_session, tenant_id, cohere_enabled)
|
||||
seed_initial_documents(db_session, tenant_id)
|
||||
|
||||
|
||||
def translate_saved_search_settings(db_session: Session) -> None:
|
||||
|
||||
@@ -1,255 +0,0 @@
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.headers import header_dict_to_header_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
    """Helper function to get image generation LLM config based on available providers.

    Preference order:
      1. the active LLM, if it is an OpenAI provider with an API key;
      2. the Azure DALL-E deployment configured via environment variables;
      3. any OpenAI provider stored in the database.

    Raises:
        ValueError: if no provider with a usable API key can be found.
    """
    if llm and llm.config.api_key and llm.config.model_provider == "openai":
        return LLMConfig(
            model_provider=llm.config.model_provider,
            model_name="dall-e-3",
            temperature=GEN_AI_TEMPERATURE,
            api_key=llm.config.api_key,
            api_base=llm.config.api_base,
            api_version=llm.config.api_version,
        )

    # Bug fix: the first branch's `if llm and ...` shows llm may be falsy/None,
    # but this branch previously dereferenced llm.config unconditionally,
    # raising AttributeError instead of falling through to the DB lookup.
    if (
        llm is not None
        and llm.config.model_provider == "azure"
        and AZURE_DALLE_API_KEY is not None
    ):
        return LLMConfig(
            model_provider="azure",
            model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
            temperature=GEN_AI_TEMPERATURE,
            api_key=AZURE_DALLE_API_KEY,
            api_base=AZURE_DALLE_API_BASE,
            api_version=AZURE_DALLE_API_VERSION,
        )

    # Fallback to checking for OpenAI provider in database
    llm_providers = fetch_existing_llm_providers(db_session)
    openai_provider = next(
        iter(
            [
                llm_provider
                for llm_provider in llm_providers
                if llm_provider.provider == "openai"
            ]
        ),
        None,
    )

    if not openai_provider or not openai_provider.api_key:
        raise ValueError("Image generation tool requires an OpenAI API key")

    return LLMConfig(
        model_provider=openai_provider.provider,
        model_name="dall-e-3",
        temperature=GEN_AI_TEMPERATURE,
        api_key=openai_provider.api_key,
        api_base=openai_provider.api_base,
        api_version=openai_provider.api_version,
    )
|
||||
|
||||
|
||||
class SearchToolConfig(BaseModel):
    """Configuration bundle consumed by construct_tools when building SearchTool."""

    # How citations/answers are rendered.
    answer_style_config: AnswerStyleConfig = Field(
        default_factory=lambda: AnswerStyleConfig(citation_config=CitationConfig())
    )
    # Mutated in place by construct_tools to account for tool token overhead.
    document_pruning_config: DocumentPruningConfig = Field(
        default_factory=DocumentPruningConfig
    )
    retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
    # Pre-selected sections, if the caller already chose documents.
    selected_sections: list[InferenceSection] | None = None
    # Context expansion around each matched chunk.
    chunks_above: int = 0
    chunks_below: int = 0
    full_doc: bool = False
    # Files attached to the latest user query, if any.
    latest_query_files: list[InMemoryChatFile] | None = None
|
||||
|
||||
|
||||
class InternetSearchToolConfig(BaseModel):
    """Configuration for InternetSearchTool construction."""

    # Internet results default to treating all returned docs as useful.
    answer_style_config: AnswerStyleConfig = Field(
        default_factory=lambda: AnswerStyleConfig(
            citation_config=CitationConfig(all_docs_useful=True)
        )
    )
|
||||
|
||||
|
||||
class ImageGenerationToolConfig(BaseModel):
    """Configuration for ImageGenerationTool construction."""

    # Extra HTTP headers forwarded to the image-generation provider.
    additional_headers: dict[str, str] | None = None
|
||||
|
||||
|
||||
class CustomToolConfig(BaseModel):
    """Configuration for OpenAPI-schema-defined custom tools."""

    # Substituted into the tool's dynamic schema (see DynamicSchemaInfo usage).
    chat_session_id: UUID | None = None
    message_id: int | None = None
    # Extra headers merged with the tool's own custom headers.
    additional_headers: dict[str, str] | None = None
|
||||
|
||||
|
||||
def construct_tools(
    persona: Persona,
    prompt_config: PromptConfig,
    db_session: Session,
    user: User | None,
    llm: LLM,
    fast_llm: LLM,
    search_tool_config: SearchToolConfig | None = None,
    internet_search_tool_config: InternetSearchToolConfig | None = None,
    image_generation_tool_config: ImageGenerationToolConfig | None = None,
    custom_tool_config: CustomToolConfig | None = None,
) -> dict[int, list[Tool]]:
    """Constructs tools based on persona configuration and available APIs.

    Iterates the persona's attached tool records and instantiates the matching
    implementation for each: built-in tools (search, image generation,
    internet search) are dispatched by class name; records carrying an
    OpenAPI schema become custom tools. Each ``*_config`` argument defaults
    to its type's default configuration when not supplied.

    Returns:
        Mapping from the DB tool record's id to the list of instantiated
        Tool objects for that record.

    Raises:
        ValueError: if the internet search tool is attached but BING_API_KEY
            is unset (and whatever _get_image_generation_config raises when
            no image provider is available).
    """
    tool_dict: dict[int, list[Tool]] = {}

    for db_tool_model in persona.tools:
        # Built-in tools are identified by in_code_tool_id and dispatched on
        # the resolved class's name.
        if db_tool_model.in_code_tool_id:
            tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)

            # Handle Search Tool
            if tool_cls.__name__ == SearchTool.__name__:
                if not search_tool_config:
                    search_tool_config = SearchToolConfig()

                search_tool = SearchTool(
                    db_session=db_session,
                    user=user,
                    persona=persona,
                    retrieval_options=search_tool_config.retrieval_options,
                    prompt_config=prompt_config,
                    llm=llm,
                    fast_llm=fast_llm,
                    pruning_config=search_tool_config.document_pruning_config,
                    answer_style_config=search_tool_config.answer_style_config,
                    selected_sections=search_tool_config.selected_sections,
                    chunks_above=search_tool_config.chunks_above,
                    chunks_below=search_tool_config.chunks_below,
                    full_doc=search_tool_config.full_doc,
                    evaluation_type=(
                        LLMEvaluationType.BASIC
                        if persona.llm_relevance_filter
                        else LLMEvaluationType.SKIP
                    ),
                )
                tool_dict[db_tool_model.id] = [search_tool]

            # Handle Image Generation Tool
            elif tool_cls.__name__ == ImageGenerationTool.__name__:
                if not image_generation_tool_config:
                    image_generation_tool_config = ImageGenerationToolConfig()

                img_generation_llm_config = _get_image_generation_config(
                    llm, db_session
                )

                tool_dict[db_tool_model.id] = [
                    ImageGenerationTool(
                        api_key=cast(str, img_generation_llm_config.api_key),
                        api_base=img_generation_llm_config.api_base,
                        api_version=img_generation_llm_config.api_version,
                        additional_headers=image_generation_tool_config.additional_headers,
                        model=img_generation_llm_config.model_name,
                    )
                ]

            # Handle Internet Search Tool
            elif tool_cls.__name__ == InternetSearchTool.__name__:
                if not internet_search_tool_config:
                    internet_search_tool_config = InternetSearchToolConfig()

                if not BING_API_KEY:
                    raise ValueError(
                        "Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
                    )
                tool_dict[db_tool_model.id] = [
                    InternetSearchTool(
                        api_key=BING_API_KEY,
                        answer_style_config=internet_search_tool_config.answer_style_config,
                        prompt_config=prompt_config,
                    )
                ]

        # Handle custom tools
        elif db_tool_model.openapi_schema:
            if not custom_tool_config:
                custom_tool_config = CustomToolConfig()

            tool_dict[db_tool_model.id] = cast(
                list[Tool],
                build_custom_tools_from_openapi_schema_and_headers(
                    db_tool_model.openapi_schema,
                    dynamic_schema_info=DynamicSchemaInfo(
                        chat_session_id=custom_tool_config.chat_session_id,
                        message_id=custom_tool_config.message_id,
                    ),
                    # Per-tool headers first, then caller-supplied extras.
                    custom_headers=(db_tool_model.custom_headers or [])
                    + (
                        header_dict_to_header_list(
                            custom_tool_config.additional_headers or {}
                        )
                    ),
                ),
            )

    tools: list[Tool] = []
    for tool_list in tool_dict.values():
        tools.extend(tool_list)

    # factor in tool definition size when pruning
    if search_tool_config:
        search_tool_config.document_pruning_config.tool_num_tokens = (
            compute_all_tool_tokens(
                tools,
                get_tokenizer(
                    model_name=llm.config.model_name,
                    provider_type=llm.config.model_provider,
                ),
            )
        )
        search_tool_config.document_pruning_config.using_tool_message = (
            explicit_tool_calling_supported(
                llm.config.model_provider, llm.config.model_name
            )
        )

    return tool_dict
|
||||
@@ -1,6 +1,5 @@
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
@@ -140,19 +139,8 @@ def fetch_ee_implementation_or_noop(
|
||||
Exception: If EE is enabled but the fetch fails.
|
||||
"""
|
||||
if not global_version.is_ee_version():
|
||||
if inspect.iscoroutinefunction(noop_return_value):
|
||||
return lambda *args, **kwargs: noop_return_value
|
||||
|
||||
async def async_noop(*args: Any, **kwargs: Any) -> Any:
|
||||
return await noop_return_value(*args, **kwargs)
|
||||
|
||||
return async_noop
|
||||
|
||||
else:
|
||||
|
||||
def sync_noop(*args: Any, **kwargs: Any) -> Any:
|
||||
return noop_return_value
|
||||
|
||||
return sync_noop
|
||||
try:
|
||||
return fetch_versioned_implementation(module, attribute)
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,7 +4,6 @@ import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import exceptions
|
||||
@@ -14,8 +13,6 @@ from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.llm import update_default_provider
|
||||
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||
from danswer.db.llm import upsert_llm_provider
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from danswer.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
|
||||
from danswer.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
|
||||
@@ -105,19 +102,9 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
setup_danswer(db_session, tenant_id)
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_danswer(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
except Exception as e:
|
||||
@@ -213,51 +200,11 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
api_key=COHERE_DEFAULT_API_KEY,
|
||||
)
|
||||
|
||||
try:
|
||||
logger.info("Attempting to upsert Cohere cloud embedding provider")
|
||||
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
|
||||
logger.info("Successfully upserted Cohere cloud embedding provider")
|
||||
|
||||
logger.info("Updating search settings with Cohere embedding model details")
|
||||
query = (
|
||||
select(SearchSettings)
|
||||
.where(SearchSettings.status == IndexModelStatus.FUTURE)
|
||||
.order_by(SearchSettings.id.desc())
|
||||
)
|
||||
result = db_session.execute(query)
|
||||
current_search_settings = result.scalars().first()
|
||||
|
||||
if current_search_settings:
|
||||
current_search_settings.model_name = (
|
||||
"embed-english-v3.0" # Cohere's latest model as of now
|
||||
)
|
||||
current_search_settings.model_dim = (
|
||||
1024 # Cohere's embed-english-v3.0 dimension
|
||||
)
|
||||
current_search_settings.provider_type = EmbeddingProvider.COHERE
|
||||
current_search_settings.index_name = (
|
||||
"danswer_chunk_cohere_embed_english_v3_0"
|
||||
)
|
||||
current_search_settings.query_prefix = ""
|
||||
current_search_settings.passage_prefix = ""
|
||||
db_session.commit()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"No search settings specified, DB is not in a valid state"
|
||||
)
|
||||
logger.info("Fetching updated search settings to verify changes")
|
||||
updated_query = (
|
||||
select(SearchSettings)
|
||||
.where(SearchSettings.status == IndexModelStatus.PRESENT)
|
||||
.order_by(SearchSettings.id.desc())
|
||||
)
|
||||
updated_result = db_session.execute(updated_query)
|
||||
updated_result.scalars().first()
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to configure Cohere embedding provider")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure Cohere embedding provider: {e}")
|
||||
else:
|
||||
logger.info(
|
||||
logger.error(
|
||||
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
|
||||
)
|
||||
|
||||
@@ -142,20 +142,6 @@ async def async_return_default_schema(*args: Any, **kwargs: Any) -> str:
|
||||
# Prefix used for all tenant ids
|
||||
TENANT_ID_PREFIX = "tenant_"
|
||||
|
||||
ALLOWED_SLACK_BOT_TENANT_IDS = os.environ.get("ALLOWED_SLACK_BOT_TENANT_IDS")
|
||||
DISALLOWED_SLACK_BOT_TENANT_LIST = (
|
||||
[tenant.strip() for tenant in ALLOWED_SLACK_BOT_TENANT_IDS.split(",")]
|
||||
if ALLOWED_SLACK_BOT_TENANT_IDS
|
||||
else None
|
||||
)
|
||||
|
||||
IGNORED_SYNCING_TENANT_IDS = os.environ.get("IGNORED_SYNCING_TENANT_ID")
|
||||
IGNORED_SYNCING_TENANT_LIST = (
|
||||
[tenant.strip() for tenant in IGNORED_SYNCING_TENANT_IDS.split(",")]
|
||||
if IGNORED_SYNCING_TENANT_IDS
|
||||
else None
|
||||
)
|
||||
|
||||
SUPPORTED_EMBEDDING_MODELS = [
|
||||
# Cloud-based models
|
||||
SupportedEmbeddingModel(
|
||||
|
||||
@@ -13,14 +13,6 @@ from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
DOMAIN = "test.com"
|
||||
DEFAULT_PASSWORD = "test"
|
||||
|
||||
|
||||
def build_email(name: str) -> str:
|
||||
return f"{name}@test.com"
|
||||
|
||||
|
||||
class UserManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
@@ -31,9 +23,9 @@ class UserManager:
|
||||
name = f"test{str(uuid4())}"
|
||||
|
||||
if email is None:
|
||||
email = build_email(name)
|
||||
email = f"{name}@test.com"
|
||||
|
||||
password = DEFAULT_PASSWORD
|
||||
password = "test"
|
||||
|
||||
body = {
|
||||
"email": email,
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user() -> DATestUser | None:
|
||||
try:
|
||||
return UserManager.create("admin_user")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("admin_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
|
||||
return LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def thread_id(admin_user: Optional[DATestUser]) -> UUID:
|
||||
# Create a thread to use in the tests
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/threads", # Updated endpoint path
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return UUID(response.json()["id"])
|
||||
@@ -1,151 +0,0 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
ASSISTANTS_URL = f"{API_SERVER_URL}/openai-assistants/assistants"
|
||||
|
||||
|
||||
def test_create_assistant(admin_user: DATestUser | None) -> None:
|
||||
response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={
|
||||
"model": "gpt-3.5-turbo",
|
||||
"name": "Test Assistant",
|
||||
"description": "A test assistant",
|
||||
"instructions": "You are a helpful assistant.",
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Assistant"
|
||||
assert data["description"] == "A test assistant"
|
||||
assert data["model"] == "gpt-3.5-turbo"
|
||||
assert data["instructions"] == "You are a helpful assistant."
|
||||
|
||||
|
||||
def test_retrieve_assistant(admin_user: DATestUser | None) -> None:
|
||||
# First, create an assistant
|
||||
create_response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": "Retrieve Test"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Now, retrieve the assistant
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}/{assistant_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == assistant_id
|
||||
assert data["name"] == "Retrieve Test"
|
||||
|
||||
|
||||
def test_modify_assistant(admin_user: DATestUser | None) -> None:
|
||||
# First, create an assistant
|
||||
create_response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": "Modify Test"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Now, modify the assistant
|
||||
response = requests.post(
|
||||
f"{ASSISTANTS_URL}/{assistant_id}",
|
||||
json={"name": "Modified Assistant", "instructions": "New instructions"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == assistant_id
|
||||
assert data["name"] == "Modified Assistant"
|
||||
assert data["instructions"] == "New instructions"
|
||||
|
||||
|
||||
def test_delete_assistant(admin_user: DATestUser | None) -> None:
|
||||
# First, create an assistant
|
||||
create_response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": "Delete Test"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Now, delete the assistant
|
||||
response = requests.delete(
|
||||
f"{ASSISTANTS_URL}/{assistant_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == assistant_id
|
||||
assert data["deleted"] is True
|
||||
|
||||
|
||||
def test_list_assistants(admin_user: DATestUser | None) -> None:
|
||||
# Create multiple assistants
|
||||
for i in range(3):
|
||||
requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": f"List Test {i}"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the assistants
|
||||
response = requests.get(
|
||||
ASSISTANTS_URL,
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["object"] == "list"
|
||||
assert len(data["data"]) >= 3 # At least the 3 we just created
|
||||
assert all(assistant["object"] == "assistant" for assistant in data["data"])
|
||||
|
||||
|
||||
def test_list_assistants_pagination(admin_user: DATestUser | None) -> None:
|
||||
# Create 5 assistants
|
||||
for i in range(5):
|
||||
requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": f"Pagination Test {i}"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# List assistants with limit
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}?limit=2",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["data"]) == 2
|
||||
assert data["has_more"] is True
|
||||
|
||||
# Get next page
|
||||
before = data["last_id"]
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}?limit=2&before={before}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["data"]) == 2
|
||||
|
||||
|
||||
def test_assistant_not_found(admin_user: DATestUser | None) -> None:
|
||||
non_existent_id = -99
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}/{non_existent_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
@@ -1,133 +0,0 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants/threads"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def thread_id(admin_user: Optional[DATestUser]) -> str:
|
||||
response = requests.post(
|
||||
BASE_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
def test_create_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages", # URL structure matches API
|
||||
json={
|
||||
"role": "user",
|
||||
"content": "Hello, world!",
|
||||
"file_ids": [],
|
||||
"metadata": {"key": "value"},
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert "id" in response_json
|
||||
assert response_json["thread_id"] == thread_id
|
||||
assert response_json["role"] == "user"
|
||||
assert response_json["content"] == [{"type": "text", "text": "Hello, world!"}]
|
||||
assert response_json["metadata"] == {"key": "value"}
|
||||
|
||||
|
||||
def test_list_messages(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
# Create a message first
|
||||
requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the messages
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["object"] == "list"
|
||||
assert isinstance(response_json["data"], list)
|
||||
assert len(response_json["data"]) > 0
|
||||
assert "first_id" in response_json
|
||||
assert "last_id" in response_json
|
||||
assert "has_more" in response_json
|
||||
|
||||
|
||||
def test_retrieve_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
# Create a message first
|
||||
create_response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
message_id = create_response.json()["id"]
|
||||
|
||||
# Now, retrieve the message
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/{thread_id}/messages/{message_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == message_id
|
||||
assert response_json["thread_id"] == thread_id
|
||||
assert response_json["role"] == "user"
|
||||
assert response_json["content"] == [{"type": "text", "text": "Test message"}]
|
||||
|
||||
|
||||
def test_modify_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
# Create a message first
|
||||
create_response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
message_id = create_response.json()["id"]
|
||||
|
||||
# Now, modify the message
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages/{message_id}",
|
||||
json={"metadata": {"new_key": "new_value"}},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == message_id
|
||||
assert response_json["thread_id"] == thread_id
|
||||
assert response_json["metadata"] == {"new_key": "new_value"}
|
||||
|
||||
|
||||
def test_error_handling(admin_user: Optional[DATestUser]) -> None:
|
||||
non_existent_thread_id = str(uuid.uuid4())
|
||||
non_existent_message_id = -99
|
||||
|
||||
# Test with non-existent thread
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/{non_existent_thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
# Test with non-existent message
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/{non_existent_thread_id}/messages/{non_existent_message_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
@@ -1,137 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def run_id(admin_user: DATestUser | None, thread_id: UUID) -> str:
|
||||
"""Create a run and return its ID."""
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
json={
|
||||
"assistant_id": 0,
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
def test_create_run(
|
||||
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
|
||||
) -> None:
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
json={
|
||||
"assistant_id": 0,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"instructions": "Test instructions",
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert "id" in response_json
|
||||
assert response_json["object"] == "thread.run"
|
||||
assert "created_at" in response_json
|
||||
assert response_json["assistant_id"] == 0
|
||||
assert UUID(response_json["thread_id"]) == thread_id
|
||||
assert response_json["status"] == "queued"
|
||||
assert response_json["model"] == "gpt-3.5-turbo"
|
||||
assert response_json["instructions"] == "Test instructions"
|
||||
|
||||
|
||||
def test_retrieve_run(
|
||||
admin_user: DATestUser | None,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
llm_provider: DATestLLMProvider,
|
||||
) -> None:
|
||||
retrieve_response = requests.get(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert retrieve_response.status_code == 200
|
||||
|
||||
response_json = retrieve_response.json()
|
||||
assert response_json["id"] == run_id
|
||||
assert response_json["object"] == "thread.run"
|
||||
assert "created_at" in response_json
|
||||
assert UUID(response_json["thread_id"]) == thread_id
|
||||
|
||||
|
||||
def test_cancel_run(
|
||||
admin_user: DATestUser | None,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
llm_provider: DATestLLMProvider,
|
||||
) -> None:
|
||||
cancel_response = requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/cancel",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert cancel_response.status_code == 200
|
||||
|
||||
response_json = cancel_response.json()
|
||||
assert response_json["id"] == run_id
|
||||
assert response_json["status"] == "cancelled"
|
||||
|
||||
|
||||
def test_list_runs(
|
||||
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
|
||||
) -> None:
|
||||
# Create a few runs
|
||||
for _ in range(3):
|
||||
requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
json={
|
||||
"assistant_id": 0,
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the runs
|
||||
list_response = requests.get(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert list_response.status_code == 200
|
||||
|
||||
response_json = list_response.json()
|
||||
assert isinstance(response_json, list)
|
||||
assert len(response_json) >= 3
|
||||
|
||||
for run in response_json:
|
||||
assert "id" in run
|
||||
assert run["object"] == "thread.run"
|
||||
assert "created_at" in run
|
||||
assert UUID(run["thread_id"]) == thread_id
|
||||
assert "status" in run
|
||||
assert "model" in run
|
||||
|
||||
|
||||
def test_list_run_steps(
|
||||
admin_user: DATestUser | None,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
llm_provider: DATestLLMProvider,
|
||||
) -> None:
|
||||
steps_response = requests.get(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/steps",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert steps_response.status_code == 200
|
||||
|
||||
response_json = steps_response.json()
|
||||
assert isinstance(response_json, list)
|
||||
# Since DAnswer doesn't have an equivalent to run steps, we expect an empty list
|
||||
assert len(response_json) == 0
|
||||
@@ -1,132 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.db.models import ChatSessionSharedStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
THREADS_URL = f"{API_SERVER_URL}/openai-assistants/threads"
|
||||
|
||||
|
||||
def test_create_thread(admin_user: DATestUser | None) -> None:
|
||||
response = requests.post(
|
||||
THREADS_URL,
|
||||
json={"messages": None, "metadata": {"key": "value"}},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert "id" in response_json
|
||||
assert response_json["object"] == "thread"
|
||||
assert "created_at" in response_json
|
||||
assert response_json["metadata"] == {"key": "value"}
|
||||
|
||||
|
||||
def test_retrieve_thread(admin_user: DATestUser | None) -> None:
|
||||
# First, create a thread
|
||||
create_response = requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
thread_id = create_response.json()["id"]
|
||||
|
||||
# Now, retrieve the thread
|
||||
retrieve_response = requests.get(
|
||||
f"{THREADS_URL}/{thread_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert retrieve_response.status_code == 200
|
||||
|
||||
response_json = retrieve_response.json()
|
||||
assert response_json["id"] == thread_id
|
||||
assert response_json["object"] == "thread"
|
||||
assert "created_at" in response_json
|
||||
|
||||
|
||||
def test_modify_thread(admin_user: DATestUser | None) -> None:
|
||||
# First, create a thread
|
||||
create_response = requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
thread_id = create_response.json()["id"]
|
||||
|
||||
# Now, modify the thread
|
||||
modify_response = requests.post(
|
||||
f"{THREADS_URL}/{thread_id}",
|
||||
json={"metadata": {"new_key": "new_value"}},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert modify_response.status_code == 200
|
||||
|
||||
response_json = modify_response.json()
|
||||
assert response_json["id"] == thread_id
|
||||
assert response_json["metadata"] == {"new_key": "new_value"}
|
||||
|
||||
|
||||
def test_delete_thread(admin_user: DATestUser | None) -> None:
|
||||
# First, create a thread
|
||||
create_response = requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
thread_id = create_response.json()["id"]
|
||||
|
||||
# Now, delete the thread
|
||||
delete_response = requests.delete(
|
||||
f"{THREADS_URL}/{thread_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert delete_response.status_code == 200
|
||||
|
||||
response_json = delete_response.json()
|
||||
assert response_json["id"] == thread_id
|
||||
assert response_json["object"] == "thread.deleted"
|
||||
assert response_json["deleted"] is True
|
||||
|
||||
|
||||
def test_list_threads(admin_user: DATestUser | None) -> None:
|
||||
# Create a few threads
|
||||
for _ in range(3):
|
||||
requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the threads
|
||||
list_response = requests.get(
|
||||
THREADS_URL,
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert list_response.status_code == 200
|
||||
|
||||
response_json = list_response.json()
|
||||
assert "sessions" in response_json
|
||||
assert len(response_json["sessions"]) >= 3
|
||||
|
||||
for session in response_json["sessions"]:
|
||||
assert "id" in session
|
||||
assert "name" in session
|
||||
assert "persona_id" in session
|
||||
assert "time_created" in session
|
||||
assert "shared_status" in session
|
||||
assert "folder_id" in session
|
||||
assert "current_alternate_model" in session
|
||||
|
||||
# Validate UUID
|
||||
UUID(session["id"])
|
||||
|
||||
# Validate shared_status
|
||||
assert session["shared_status"] in [
|
||||
status.value for status in ChatSessionSharedStatus
|
||||
]
|
||||
@@ -9,11 +9,12 @@ spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-indexing
|
||||
minReplicaCount: 1
|
||||
maxReplicaCount: 30
|
||||
maxReplicaCount: 10
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing
|
||||
@@ -21,10 +22,10 @@ spec:
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing:2
|
||||
@@ -35,6 +36,7 @@ spec:
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing:3
|
||||
@@ -42,12 +44,3 @@ spec:
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: cpu
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
||||
|
||||
- type: memory
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
||||
|
||||
@@ -8,11 +8,12 @@ metadata:
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-light
|
||||
minReplicaCount: 5
|
||||
minReplicaCount: 1
|
||||
maxReplicaCount: 20
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync
|
||||
@@ -22,6 +23,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync:2
|
||||
@@ -31,6 +33,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync:3
|
||||
@@ -40,6 +43,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_deletion
|
||||
@@ -49,6 +53,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_deletion:2
|
||||
|
||||
@@ -15,6 +15,7 @@ spec:
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery
|
||||
@@ -25,6 +26,7 @@ spec:
|
||||
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:1
|
||||
@@ -34,6 +36,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:2
|
||||
@@ -43,6 +46,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:3
|
||||
@@ -52,6 +56,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: periodic_tasks
|
||||
@@ -61,6 +66,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: periodic_tasks:2
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: indexing-model-server-scaledobject
|
||||
namespace: danswer
|
||||
labels:
|
||||
app: indexing-model-server
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: indexing-model-server-deployment
|
||||
pollingInterval: 15 # Check every 15 seconds
|
||||
cooldownPeriod: 30 # Wait 30 seconds before scaling down
|
||||
minReplicaCount: 1
|
||||
maxReplicaCount: 14
|
||||
triggers:
|
||||
- type: cpu
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
||||
@@ -5,5 +5,5 @@ metadata:
|
||||
namespace: danswer
|
||||
type: Opaque
|
||||
data:
|
||||
host: { base64 encoded host here }
|
||||
password: { base64 encoded password here }
|
||||
host: { { base64-encoded-hostname } }
|
||||
password: { { base64-encoded-password } }
|
||||
|
||||
@@ -14,8 +14,8 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-beat
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
|
||||
@@ -14,8 +14,8 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-heavy
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
|
||||
@@ -14,8 +14,8 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-indexing
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
@@ -47,10 +47,10 @@ spec:
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "4Gi"
|
||||
memory: "1Gi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "8Gi"
|
||||
memory: "2Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
|
||||
@@ -14,8 +14,8 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-light
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
|
||||
@@ -14,8 +14,8 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-primary
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
ASSISTANT_NAME = "Topic Analyzer"
|
||||
SYSTEM_PROMPT = """
|
||||
You are a helpful assistant that analyzes topics by searching through available \
|
||||
documents and providing insights. These available documents come from common \
|
||||
workplace tools like Slack, emails, Confluence, Google Drive, etc.
|
||||
|
||||
When analyzing a topic:
|
||||
1. Search for relevant information using the search tool
|
||||
2. Synthesize the findings into clear insights
|
||||
3. Highlight key trends, patterns, or notable developments
|
||||
4. Maintain objectivity and cite sources where relevant
|
||||
"""
|
||||
USER_PROMPT = """
|
||||
Please analyze and provide insights about this topic: {topic}.
|
||||
|
||||
IMPORTANT: do not mention things that are not relevant to the specified topic. \
|
||||
If there is no relevant information, just say "No relevant information found."
|
||||
"""
|
||||
|
||||
|
||||
def wait_on_run(client: OpenAI, run, thread):
|
||||
while run.status == "queued" or run.status == "in_progress":
|
||||
run = client.beta.threads.runs.retrieve(
|
||||
thread_id=thread.id,
|
||||
run_id=run.id,
|
||||
)
|
||||
time.sleep(0.5)
|
||||
return run
|
||||
|
||||
|
||||
def show_response(messages) -> None:
|
||||
# Get only the assistant's response text
|
||||
for message in messages.data[::-1]:
|
||||
if message.role == "assistant":
|
||||
for content in message.content:
|
||||
if content.type == "text":
|
||||
print(content.text)
|
||||
break
|
||||
|
||||
|
||||
def analyze_topics(topics: list[str]) -> None:
    """Run a Danswer-backed OpenAI assistant over each topic and print results.

    For every topic, a fresh thread is created, a search-augmented run is
    executed (restricted to documents from the last 7 days), and the
    assistant's answer is printed to stdout.

    Args:
        topics: Topics to analyze; one analysis is printed per topic.
    """
    openai_api_key = os.environ.get(
        "OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"
    )
    danswer_api_key = os.environ.get(
        "DANSWER_API_KEY", "<your Danswer API key if not set as env var>"
    )
    # All requests go through Danswer's OpenAI-compatible proxy, which needs
    # its own bearer token in addition to the OpenAI API key.
    client = OpenAI(
        api_key=openai_api_key,
        base_url="http://localhost:8080/openai-assistants",
        default_headers={
            "Authorization": f"Bearer {danswer_api_key}",
        },
    )

    # Start from a clean slate so stale instructions/tools are not reused.
    _delete_existing_assistant(client)

    assistant = client.beta.assistants.create(
        name=ASSISTANT_NAME,
        instructions=SYSTEM_PROMPT,
        tools=[{"type": "SearchTool"}],  # type: ignore
        model="gpt-4o",
    )

    # Process each topic individually in its own thread.
    for topic in topics:
        _analyze_single_topic(client, assistant, topic)


def _delete_existing_assistant(client) -> None:
    """Best-effort removal of a previously created Topic Analyzer assistant."""
    try:
        assistants = client.beta.assistants.list(limit=100)
        # `next` with a default makes the "not found" case explicit instead of
        # relying on StopIteration being swallowed by the broad except below.
        existing = next(
            (a for a in assistants.data if a.name == ASSISTANT_NAME), None
        )
        if existing is not None:
            client.beta.assistants.delete(existing.id)
    except Exception:
        # Cleanup is best-effort: a failed list/delete must not block analysis.
        pass


def _analyze_single_topic(client, assistant, topic: str) -> None:
    """Create a thread for ``topic``, run the assistant, and print the answer."""
    thread = client.beta.threads.create()
    message = client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=USER_PROMPT.format(topic=topic),
    )

    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant.id,
        tools=[
            {  # type: ignore
                "type": "SearchTool",
                "retrieval_details": {
                    "run_search": "always",
                    "filters": {
                        # Only consider documents from the last week.
                        "time_cutoff": str(
                            datetime.now(timezone.utc) - timedelta(days=7)
                        )
                    },
                },
            }
        ],
    )

    run = wait_on_run(client, run, thread)
    # Fetch only messages created after the user prompt, oldest first.
    messages = client.beta.threads.messages.list(
        thread_id=thread.id, order="asc", after=message.id
    )
    print(f"\nAnalysis for topic: {topic}")
    print("-" * 40)
    show_response(messages)
    print()
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
    # CLI entry point: accept one or more topics as positional arguments.
    arg_parser = argparse.ArgumentParser(description="Analyze specific topics")
    arg_parser.add_argument(
        "topics", nargs="+", help="Topics to analyze (one or more)"
    )
    cli_args = arg_parser.parse_args()

    analyze_topics(cli_args.topics)
|
||||
@@ -68,7 +68,7 @@ export function IndexAttemptStatus({
|
||||
);
|
||||
} else if (status === "in_progress") {
|
||||
badge = (
|
||||
<Badge variant="in_progress" icon={FiClock}>
|
||||
<Badge className="flex-none" variant="in_progress" icon={FiClock}>
|
||||
In Progress
|
||||
</Badge>
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user