mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 00:05:47 +00:00
Compare commits
6 Commits
improved_c
...
user_defau
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fff98ddc15 | ||
|
|
f4a020b599 | ||
|
|
5166649eae | ||
|
|
ba805f766f | ||
|
|
9d57f34c34 | ||
|
|
cc2f584321 |
@@ -1,24 +1,20 @@
|
||||
# This workflow is intentionally disabled while we're still working on it
|
||||
# It's close to ready, but a race condition needs to be fixed with
|
||||
# API server and Vespa startup, and it needs to have a way to build/test against
|
||||
# local containers
|
||||
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
jobs:
|
||||
lint-test:
|
||||
helm-chart-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -28,7 +24,7 @@ jobs:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
@@ -45,24 +41,31 @@ jobs:
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
|
||||
echo "default_branch: ${{ github.event.repository.default_branch }}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --config ct.yaml
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
@@ -288,6 +288,15 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
|
||||
# below
|
||||
op.execute("DELETE FROM chat_feedback")
|
||||
op.execute("DELETE FROM chat_message__search_doc")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM chat_message")
|
||||
op.execute("DELETE FROM chat_session")
|
||||
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
|
||||
@@ -23,6 +23,56 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete chat messages and feedback first since they reference chat sessions
|
||||
# Get chat messages from sessions with null persona_id
|
||||
chat_messages_query = """
|
||||
SELECT id
|
||||
FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
|
||||
# Delete dependent records first
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM document_retrieval_feedback
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM chat_message__search_doc
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Delete chat messages
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Now we can safely delete the chat sessions
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
|
||||
@@ -100,6 +100,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -463,8 +468,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
has_web_login = attributes.get_attribute(user, "has_web_login")
|
||||
|
||||
if not has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
raise BasicAuthenticationError(
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
|
||||
@@ -621,14 +625,12 @@ async def double_check_user(
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
@@ -637,8 +639,7 @@ async def double_check_user(
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
@@ -664,15 +665,13 @@ async def current_curator_or_admin_user(
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated or lacks role information.",
|
||||
)
|
||||
|
||||
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
|
||||
if user.role not in allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not a curator or admin.",
|
||||
)
|
||||
|
||||
@@ -684,8 +683,7 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
)
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
@@ -73,15 +72,6 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
IGNORED_SYNCING_TENANT_LIST
|
||||
and tenant_id in IGNORED_SYNCING_TENANT_LIST
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
continue
|
||||
|
||||
if tenant_id not in existing_tenants:
|
||||
logger.info(f"Processing new tenant: {tenant_id}")
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
@@ -82,11 +81,6 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
|
||||
96
backend/danswer/background/celery/apps/scheduler.py
Normal file
96
backend/danswer/background/celery/apps/scheduler.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
|
||||
class DynamicTenantScheduler(PersistentScheduler):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._reload_interval = timedelta(minutes=1)
|
||||
self._last_reload = self.app.now() - self._reload_interval
|
||||
|
||||
def setup_schedule(self) -> None:
|
||||
super().setup_schedule()
|
||||
|
||||
def tick(self) -> float:
|
||||
retval = super().tick()
|
||||
now = self.app.now()
|
||||
if (
|
||||
self._last_reload is None
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reloading schedule to check for new tenants...")
|
||||
self._update_tenant_tasks()
|
||||
self._last_reload = now
|
||||
return retval
|
||||
|
||||
def _update_tenant_tasks(self) -> None:
|
||||
logger.info("Checking for tenant task updates...")
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
tasks_to_schedule = fetch_versioned_implementation(
|
||||
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
)
|
||||
|
||||
new_beat_schedule: dict[str, dict[str, Any]] = {}
|
||||
|
||||
current_schedule = getattr(self, "_store", {"entries": {}}).get(
|
||||
"entries", {}
|
||||
)
|
||||
|
||||
existing_tenants = set()
|
||||
for task_name in current_schedule.keys():
|
||||
if "-" in task_name:
|
||||
existing_tenants.add(task_name.split("-")[-1])
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id not in existing_tenants:
|
||||
logger.info(f"Found new tenant: {tenant_id}")
|
||||
|
||||
for task in tasks_to_schedule():
|
||||
task_name = f"{task['name']}-{tenant_id}"
|
||||
new_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
new_task["options"] = options
|
||||
new_beat_schedule[task_name] = new_task
|
||||
|
||||
if self._should_update_schedule(current_schedule, new_beat_schedule):
|
||||
logger.info(
|
||||
"Updating schedule",
|
||||
extra={
|
||||
"new_tasks": len(new_beat_schedule),
|
||||
"current_tasks": len(current_schedule),
|
||||
},
|
||||
)
|
||||
if not hasattr(self, "_store"):
|
||||
self._store: dict[str, dict] = {"entries": {}}
|
||||
self.update_from_dict(new_beat_schedule)
|
||||
logger.info(f"New schedule: {new_beat_schedule}")
|
||||
|
||||
logger.info("Tenant tasks updated successfully")
|
||||
else:
|
||||
logger.debug("No schedule updates needed")
|
||||
|
||||
except (AttributeError, KeyError):
|
||||
logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error updating tenant tasks")
|
||||
|
||||
def _should_update_schedule(
|
||||
self, current_schedule: dict, new_schedule: dict
|
||||
) -> bool:
|
||||
"""Compare schedules to determine if an update is needed."""
|
||||
current_tasks = set(current_schedule.keys())
|
||||
new_tasks = set(new_schedule.keys())
|
||||
return current_tasks != new_tasks
|
||||
@@ -8,7 +8,7 @@ tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
@@ -20,13 +20,13 @@ tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -29,26 +29,18 @@ JobStatusType = (
|
||||
def _initializer(
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Initialize the child process with a fresh SQLAlchemy Engine.
|
||||
"""Ensure the parent proc's database connections are not touched
|
||||
in the new connection pool
|
||||
|
||||
Based on SQLAlchemy's recommendations to handle multiprocessing:
|
||||
Based on the recommended approach in the SQLAlchemy docs found:
|
||||
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
logger.info("Initializing spawned worker child process.")
|
||||
|
||||
# Reset the engine in the child process
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
# Optionally set a custom app name for database logging purposes
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
||||
|
||||
# Initialize a new engine with desired parameters
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
|
||||
|
||||
# Proceed with executing the target function
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -19,16 +19,10 @@ from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
@@ -41,7 +35,6 @@ from danswer.db.chat import reserve_message_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
@@ -61,14 +54,13 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
@@ -77,14 +69,14 @@ from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.tool_constructor import construct_tools
|
||||
from danswer.tools.tool_constructor import CustomToolConfig
|
||||
from danswer.tools.tool_constructor import ImageGenerationToolConfig
|
||||
from danswer.tools.tool_constructor import InternetSearchToolConfig
|
||||
from danswer.tools.tool_constructor import SearchToolConfig
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
@@ -95,9 +87,6 @@ from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
@@ -122,9 +111,6 @@ from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.headers import header_dict_to_header_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
@@ -295,7 +281,6 @@ def stream_chat_message_objects(
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
@@ -307,6 +292,9 @@ def stream_chat_message_objects(
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
|
||||
|
||||
# Currently surrounding context is not supported for chat
|
||||
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
|
||||
new_msg_req.chunks_above = 0
|
||||
@@ -428,12 +416,20 @@ def stream_chat_message_objects(
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
if existing_assistant_message_id is None:
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
else:
|
||||
if final_msg.id != existing_assistant_message_id:
|
||||
raise RuntimeError(
|
||||
"The last message was not the existing assistant message. "
|
||||
f"Final message id: {final_msg.id}, "
|
||||
f"existing assistant message id: {existing_assistant_message_id}"
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
@@ -504,13 +500,19 @@ def stream_chat_message_objects(
|
||||
),
|
||||
max_window_percentage=max_document_percentage,
|
||||
)
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
|
||||
# we don't need to reserve a message id if we're using an existing assistant message
|
||||
reserved_message_id = (
|
||||
final_msg.id
|
||||
if existing_assistant_message_id is not None
|
||||
else reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
)
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id if user_message else None,
|
||||
@@ -525,7 +527,13 @@ def stream_chat_message_objects(
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=final_msg,
|
||||
# if we're using an existing assistant message, then this will just be an
|
||||
# update operation, in which case the parent should be the parent of
|
||||
# the latest. If we're creating a new assistant message, then the parent
|
||||
# should be the latest message (latest user message)
|
||||
parent_message=(
|
||||
final_msg if existing_assistant_message_id is None else parent_message
|
||||
),
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
# message=,
|
||||
@@ -537,6 +545,7 @@ def stream_chat_message_objects(
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
if not final_msg.prompt:
|
||||
@@ -560,142 +569,39 @@ def stream_chat_message_objects(
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
)
|
||||
|
||||
# find out what tools to use
|
||||
search_tool: SearchTool | None = None
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
answer_style_config=answer_style_config,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
img_generation_llm_config: LLMConfig | None = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm.config.api_key,
|
||||
api_base=llm.config.api_base,
|
||||
api_version=llm.config.api_version,
|
||||
)
|
||||
elif (
|
||||
llm.config.model_provider == "azure"
|
||||
and AZURE_DALLE_API_KEY is not None
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider="azure",
|
||||
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=AZURE_DALLE_API_KEY,
|
||||
api_base=AZURE_DALLE_API_BASE,
|
||||
api_version=AZURE_DALLE_API_VERSION,
|
||||
)
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
model=img_generation_llm_config.model_name,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
bing_api_key = BING_API_KEY
|
||||
if not bing_api_key:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(
|
||||
api_key=bing_api_key,
|
||||
answer_style_config=answer_style_config,
|
||||
prompt_config=prompt_config,
|
||||
)
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
custom_headers=(db_tool_model.custom_headers or [])
|
||||
+ (
|
||||
header_dict_to_header_list(
|
||||
custom_tool_additional_headers or {}
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
prompt_config=prompt_config,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
),
|
||||
image_generation_tool_config=ImageGenerationToolConfig(
|
||||
additional_headers=litellm_additional_headers,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
tools, llm_tokenizer
|
||||
)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm_provider, llm_model_name
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
is_connected=is_connected,
|
||||
@@ -871,7 +777,6 @@ def stream_chat_message_objects(
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
@@ -879,9 +784,11 @@ def stream_chat_message_objects(
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None,
|
||||
citations=(
|
||||
message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None
|
||||
),
|
||||
error=None,
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
@@ -915,7 +822,6 @@ def stream_chat_message_objects(
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
@@ -925,7 +831,6 @@ def stream_chat_message(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
is_connected=is_connected,
|
||||
|
||||
@@ -55,11 +55,11 @@ def validate_channel_names(
|
||||
# Scaling configurations for multi-tenant Slack bot handling
|
||||
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
|
||||
TENANT_HEARTBEAT_INTERVAL = (
|
||||
15 # How often pods send heartbeats to indicate they are still processing a tenant
|
||||
60 # How often pods send heartbeats to indicate they are still processing a tenant
|
||||
)
|
||||
TENANT_HEARTBEAT_EXPIRATION = (
|
||||
30 # How long before a tenant's heartbeat expires, allowing other pods to take over
|
||||
TENANT_HEARTBEAT_EXPIRATION = 180 # How long before a tenant's heartbeat expires, allowing other pods to take over
|
||||
TENANT_ACQUISITION_INTERVAL = (
|
||||
60 # How often pods attempt to acquire unprocessed tenants
|
||||
)
|
||||
TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and checks for new tokens
|
||||
|
||||
MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))
|
||||
|
||||
@@ -75,7 +75,6 @@ from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import DISALLOWED_SLACK_BOT_TENANT_LIST
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -165,15 +164,9 @@ class SlackbotHandler:
|
||||
|
||||
def acquire_tenants(self) -> None:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
|
||||
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
|
||||
):
|
||||
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
|
||||
continue
|
||||
|
||||
if tenant_id in self.tenant_ids:
|
||||
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
|
||||
continue
|
||||
@@ -197,9 +190,6 @@ class SlackbotHandler:
|
||||
continue
|
||||
|
||||
logger.debug(f"Acquired lock for tenant {tenant_id}")
|
||||
self.tenant_ids.add(tenant_id)
|
||||
|
||||
for tenant_id in self.tenant_ids:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
|
||||
tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
@@ -246,14 +236,14 @@ class SlackbotHandler:
|
||||
|
||||
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
|
||||
|
||||
if self.socket_clients.get(tenant_id):
|
||||
if tenant_id in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
|
||||
self.start_socket_client(tenant_id, slack_bot_tokens)
|
||||
|
||||
except KvKeyNotFoundError:
|
||||
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
|
||||
if self.socket_clients.get(tenant_id):
|
||||
if tenant_id in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
del self.socket_clients[tenant_id]
|
||||
del self.slack_bot_tokens[tenant_id]
|
||||
@@ -287,14 +277,14 @@ class SlackbotHandler:
|
||||
logger.info(f"Connecting socket client for tenant {tenant_id}")
|
||||
socket_client.connect()
|
||||
self.socket_clients[tenant_id] = socket_client
|
||||
self.tenant_ids.add(tenant_id)
|
||||
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
|
||||
|
||||
def stop_socket_clients(self) -> None:
|
||||
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
|
||||
for tenant_id, client in self.socket_clients.items():
|
||||
if client:
|
||||
asyncio.run(client.close())
|
||||
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
|
||||
asyncio.run(client.close())
|
||||
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
|
||||
|
||||
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
|
||||
if not self.running:
|
||||
@@ -308,16 +298,6 @@ class SlackbotHandler:
|
||||
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
|
||||
self.stop_socket_clients()
|
||||
|
||||
# Release locks for all tenants
|
||||
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
|
||||
for tenant_id in self.tenant_ids:
|
||||
try:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(DanswerRedisLocks.SLACK_BOT_LOCK)
|
||||
logger.info(f"Released lock for tenant {tenant_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
|
||||
|
||||
# Wait for background threads to finish (with timeout)
|
||||
logger.info("Waiting for background threads to finish...")
|
||||
self.acquire_thread.join(timeout=5)
|
||||
|
||||
@@ -189,13 +189,6 @@ class SqlEngine:
|
||||
return ""
|
||||
return cls._app_name
|
||||
|
||||
@classmethod
|
||||
def reset_engine(cls) -> None:
|
||||
with cls._lock:
|
||||
if cls._engine:
|
||||
cls._engine.dispose()
|
||||
cls._engine = None
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
if not MULTI_TENANT:
|
||||
|
||||
@@ -24,6 +24,13 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
|
||||
return tool
|
||||
|
||||
|
||||
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
|
||||
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
|
||||
if not tool:
|
||||
raise ValueError("Tool by specified name does not exist")
|
||||
return tool
|
||||
|
||||
|
||||
def create_tool(
|
||||
name: str,
|
||||
description: str | None,
|
||||
@@ -37,7 +44,7 @@ def create_tool(
|
||||
description=description,
|
||||
in_code_tool_id=None,
|
||||
openapi_schema=openapi_schema,
|
||||
custom_headers=[header.dict() for header in custom_headers]
|
||||
custom_headers=[header.model_dump() for header in custom_headers]
|
||||
if custom_headers
|
||||
else [],
|
||||
user_id=user_id,
|
||||
|
||||
@@ -25,6 +25,7 @@ from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import BasicAuthenticationError
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.configs.app_configs import APP_API_PREFIX
|
||||
from danswer.configs.app_configs import APP_HOST
|
||||
@@ -73,6 +74,9 @@ from danswer.server.manage.search_settings import router as search_settings_rout
|
||||
from danswer.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from danswer.server.manage.users import router as user_router
|
||||
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
|
||||
from danswer.server.openai_assistants_api.full_openai_assistants_api import (
|
||||
get_full_openai_assistants_api_router,
|
||||
)
|
||||
from danswer.server.query_and_chat.chat_backend import router as chat_router
|
||||
from danswer.server.query_and_chat.query_backend import (
|
||||
admin_router as admin_query_router,
|
||||
@@ -194,7 +198,12 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
||||
status_code = getattr(exc, "status_code", 500)
|
||||
if status_code >= 400:
|
||||
|
||||
if isinstance(exc, BasicAuthenticationError):
|
||||
# For BasicAuthenticationError, just log a brief message without stack trace (almost always spam)
|
||||
logger.error(f"Authentication failed: {str(exc)}")
|
||||
|
||||
elif status_code >= 400:
|
||||
error_msg = f"{str(exc)}\n"
|
||||
error_msg += "".join(traceback.format_tb(exc.__traceback__))
|
||||
logger.error(error_msg)
|
||||
@@ -220,7 +229,6 @@ def get_application() -> FastAPI:
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
# Add the custom exception handler
|
||||
application.add_exception_handler(status.HTTP_400_BAD_REQUEST, log_http_error)
|
||||
application.add_exception_handler(status.HTTP_401_UNAUTHORIZED, log_http_error)
|
||||
application.add_exception_handler(status.HTTP_403_FORBIDDEN, log_http_error)
|
||||
@@ -265,6 +273,9 @@ def get_application() -> FastAPI:
|
||||
application, token_rate_limit_settings_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, indexing_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, get_full_openai_assistants_api_router()
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
|
||||
@@ -63,7 +63,6 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
|
||||
@@ -11,7 +11,6 @@ from fastapi import Body
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import status
|
||||
from psycopg2.errors import UniqueViolation
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column
|
||||
@@ -27,6 +26,7 @@ from danswer.auth.noauth_user import fetch_no_auth_user
|
||||
from danswer.auth.noauth_user import set_no_auth_user_preferences
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserStatus
|
||||
from danswer.auth.users import BasicAuthenticationError
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
@@ -492,13 +492,10 @@ def verify_user_logged_in(
|
||||
store = get_kv_store()
|
||||
return fetch_no_auth_user(store)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="User Not Authenticated"
|
||||
)
|
||||
raise BasicAuthenticationError(detail="User Not Authenticated")
|
||||
|
||||
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
|
||||
273
backend/danswer/server/openai_assistants_api/asssistants_api.py
Normal file
273
backend/danswer/server/openai_assistants_api/asssistants_api.py
Normal file
@@ -0,0 +1,273 @@
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.persona import get_personas
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.persona import upsert_prompt
|
||||
from danswer.db.tools import get_tool_by_name
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
router = APIRouter(prefix="/assistants")
|
||||
|
||||
|
||||
# Base models
|
||||
class AssistantObject(BaseModel):
|
||||
id: int
|
||||
object: str = "assistant"
|
||||
created_at: int
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
model: str
|
||||
instructions: Optional[str] = None
|
||||
tools: list[dict[str, Any]]
|
||||
file_ids: list[str]
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class CreateAssistantRequest(BaseModel):
|
||||
model: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
tools: Optional[list[dict[str, Any]]] = None
|
||||
file_ids: Optional[list[str]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class ModifyAssistantRequest(BaseModel):
|
||||
model: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
tools: Optional[list[dict[str, Any]]] = None
|
||||
file_ids: Optional[list[str]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class DeleteAssistantResponse(BaseModel):
|
||||
id: int
|
||||
object: str = "assistant.deleted"
|
||||
deleted: bool
|
||||
|
||||
|
||||
class ListAssistantsResponse(BaseModel):
|
||||
object: str = "list"
|
||||
data: list[AssistantObject]
|
||||
first_id: Optional[int] = None
|
||||
last_id: Optional[int] = None
|
||||
has_more: bool
|
||||
|
||||
|
||||
def persona_to_assistant(persona: Persona) -> AssistantObject:
|
||||
return AssistantObject(
|
||||
id=persona.id,
|
||||
created_at=0,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
model=persona.llm_model_version_override or "gpt-3.5-turbo",
|
||||
instructions=persona.prompts[0].system_prompt if persona.prompts else None,
|
||||
tools=[
|
||||
{
|
||||
"type": tool.display_name,
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"schema": tool.openapi_schema,
|
||||
},
|
||||
}
|
||||
for tool in persona.tools
|
||||
],
|
||||
file_ids=[], # Assuming no file support for now
|
||||
metadata={}, # Assuming no metadata for now
|
||||
)
|
||||
|
||||
|
||||
# API endpoints
|
||||
@router.post("")
|
||||
def create_assistant(
|
||||
request: CreateAssistantRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
prompt = None
|
||||
if request.instructions:
|
||||
prompt = upsert_prompt(
|
||||
user=user,
|
||||
name=f"Prompt for {request.name or 'New Assistant'}",
|
||||
description="Auto-generated prompt",
|
||||
system_prompt=request.instructions,
|
||||
task_prompt="",
|
||||
include_citations=True,
|
||||
datetime_aware=True,
|
||||
personas=[],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
tool_ids = []
|
||||
for tool in request.tools or []:
|
||||
tool_type = tool.get("type")
|
||||
if not tool_type:
|
||||
continue
|
||||
|
||||
try:
|
||||
tool_db = get_tool_by_name(tool_type, db_session)
|
||||
tool_ids.append(tool_db.id)
|
||||
except ValueError:
|
||||
# Skip tools that don't exist in the database
|
||||
logger.error(f"Tool {tool_type} not found in database")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Tool {tool_type} not found in database"
|
||||
)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=request.name or f"Assistant-{uuid4()}",
|
||||
description=request.description or "",
|
||||
num_chunks=25,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=request.model,
|
||||
starter_messages=None,
|
||||
is_public=False,
|
||||
db_session=db_session,
|
||||
prompt_ids=[prompt.id] if prompt else [0],
|
||||
document_set_ids=[],
|
||||
tool_ids=tool_ids,
|
||||
icon_color=None,
|
||||
icon_shape=None,
|
||||
is_visible=True,
|
||||
)
|
||||
|
||||
if prompt:
|
||||
prompt.personas = [persona]
|
||||
db_session.commit()
|
||||
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
""
|
||||
|
||||
|
||||
@router.get("/{assistant_id}")
|
||||
def retrieve_assistant(
|
||||
assistant_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
try:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
except ValueError:
|
||||
persona = None
|
||||
|
||||
if not persona:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
@router.post("/{assistant_id}")
|
||||
def modify_assistant(
|
||||
assistant_id: int,
|
||||
request: ModifyAssistantRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=True,
|
||||
)
|
||||
if not persona:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(persona, key, value)
|
||||
|
||||
if "instructions" in update_data and persona.prompts:
|
||||
persona.prompts[0].system_prompt = update_data["instructions"]
|
||||
|
||||
db_session.commit()
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
@router.delete("/{assistant_id}")
|
||||
def delete_assistant(
|
||||
assistant_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DeleteAssistantResponse:
|
||||
try:
|
||||
mark_persona_as_deleted(
|
||||
persona_id=int(assistant_id),
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
return DeleteAssistantResponse(id=assistant_id, deleted=True)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_assistants(
|
||||
limit: int = Query(20, le=100),
|
||||
order: str = Query("desc", regex="^(asc|desc)$"),
|
||||
after: Optional[int] = None,
|
||||
before: Optional[int] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ListAssistantsResponse:
|
||||
personas = list(
|
||||
get_personas(
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
get_editable=False,
|
||||
joinedload_all=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply filtering based on after and before
|
||||
if after:
|
||||
personas = [p for p in personas if p.id > int(after)]
|
||||
if before:
|
||||
personas = [p for p in personas if p.id < int(before)]
|
||||
|
||||
# Apply ordering
|
||||
personas.sort(key=lambda p: p.id, reverse=(order == "desc"))
|
||||
|
||||
# Apply limit
|
||||
personas = personas[:limit]
|
||||
|
||||
assistants = [persona_to_assistant(p) for p in personas]
|
||||
|
||||
return ListAssistantsResponse(
|
||||
data=assistants,
|
||||
first_id=assistants[0].id if assistants else None,
|
||||
last_id=assistants[-1].id if assistants else None,
|
||||
has_more=len(personas) == limit,
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from danswer.server.openai_assistants_api.asssistants_api import (
|
||||
router as assistants_router,
|
||||
)
|
||||
from danswer.server.openai_assistants_api.messages_api import router as messages_router
|
||||
from danswer.server.openai_assistants_api.runs_api import router as runs_router
|
||||
from danswer.server.openai_assistants_api.threads_api import router as threads_router
|
||||
|
||||
|
||||
def get_full_openai_assistants_api_router() -> APIRouter:
|
||||
router = APIRouter(prefix="/openai-assistants")
|
||||
|
||||
router.include_router(assistants_router)
|
||||
router.include_router(runs_router)
|
||||
router.include_router(threads_router)
|
||||
router.include_router(messages_router)
|
||||
|
||||
return router
|
||||
235
backend/danswer/server/openai_assistants_api/messages_api.py
Normal file
235
backend/danswer/server/openai_assistants_api/messages_api.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
|
||||
router = APIRouter(prefix="")
|
||||
|
||||
|
||||
Role = Literal["user", "assistant"]
|
||||
|
||||
|
||||
class MessageContent(BaseModel):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}")
|
||||
object: Literal["thread.message"] = "thread.message"
|
||||
created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
thread_id: str
|
||||
role: Role
|
||||
content: list[MessageContent]
|
||||
file_ids: list[str] = []
|
||||
assistant_id: Optional[str] = None
|
||||
run_id: Optional[str] = None
|
||||
metadata: Optional[dict[str, Any]] = None # Change this line to use dict[str, Any]
|
||||
|
||||
|
||||
class CreateMessageRequest(BaseModel):
|
||||
role: Role
|
||||
content: str
|
||||
file_ids: list[str] = []
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
class ListMessagesResponse(BaseModel):
|
||||
object: Literal["list"] = "list"
|
||||
data: list[Message]
|
||||
first_id: str
|
||||
last_id: str
|
||||
has_more: bool
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/messages")
|
||||
def create_message(
|
||||
thread_id: str,
|
||||
message: CreateMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=uuid.UUID(thread_id),
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
latest_message = (
|
||||
chat_messages[-1]
|
||||
if chat_messages
|
||||
else get_or_create_root_message(chat_session.id, db_session)
|
||||
)
|
||||
|
||||
new_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=latest_message,
|
||||
message=message.content,
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
token_count=check_number_of_tokens(message.content),
|
||||
message_type=(
|
||||
MessageType.USER if message.role == "user" else MessageType.ASSISTANT
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return Message(
|
||||
id=str(new_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user",
|
||||
content=[MessageContent(type="text", text=message.content)],
|
||||
file_ids=message.file_ids,
|
||||
metadata=message.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/messages")
|
||||
def list_messages(
|
||||
thread_id: str,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ListMessagesResponse:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=uuid.UUID(thread_id),
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Apply filtering based on after and before
|
||||
if after:
|
||||
messages = [m for m in messages if str(m.id) >= after]
|
||||
if before:
|
||||
messages = [m for m in messages if str(m.id) <= before]
|
||||
|
||||
# Apply ordering
|
||||
messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))
|
||||
|
||||
# Apply limit
|
||||
messages = messages[:limit]
|
||||
|
||||
data = [
|
||||
Message(
|
||||
id=str(m.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if m.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=m.message)],
|
||||
created_at=int(m.time_sent.timestamp()),
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
return ListMessagesResponse(
|
||||
data=data,
|
||||
first_id=str(data[0].id) if data else "",
|
||||
last_id=str(data[-1].id) if data else "",
|
||||
has_more=len(messages) == limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/messages/{message_id}")
|
||||
def retrieve_message(
|
||||
thread_id: str,
|
||||
message_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=message_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
return Message(
|
||||
id=str(chat_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if chat_message.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=chat_message.message)],
|
||||
created_at=int(chat_message.time_sent.timestamp()),
|
||||
)
|
||||
|
||||
|
||||
class ModifyMessageRequest(BaseModel):
|
||||
metadata: dict
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/messages/{message_id}")
|
||||
def modify_message(
|
||||
thread_id: str,
|
||||
message_id: int,
|
||||
request: ModifyMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=message_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
# Update metadata
|
||||
# TODO: Uncomment this once we have metadata in the chat message
|
||||
# chat_message.metadata = request.metadata
|
||||
# db_session.commit()
|
||||
|
||||
return Message(
|
||||
id=str(chat_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if chat_message.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=chat_message.message)],
|
||||
created_at=int(chat_message.time_sent.timestamp()),
|
||||
metadata=request.metadata,
|
||||
)
|
||||
344
backend/danswer/server/openai_assistants_api/runs_api.py
Normal file
344
backend/danswer/server/openai_assistants_api/runs_api.py
Normal file
@@ -0,0 +1,344 @@
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import User
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
assistant_id: int
|
||||
model: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
additional_instructions: Optional[str] = None
|
||||
tools: Optional[list[dict]] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
RunStatus = Literal[
|
||||
"queued",
|
||||
"in_progress",
|
||||
"requires_action",
|
||||
"cancelling",
|
||||
"cancelled",
|
||||
"failed",
|
||||
"completed",
|
||||
"expired",
|
||||
]
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
id: str
|
||||
object: Literal["thread.run"]
|
||||
created_at: int
|
||||
assistant_id: int
|
||||
thread_id: UUID
|
||||
status: RunStatus
|
||||
started_at: Optional[int] = None
|
||||
expires_at: Optional[int] = None
|
||||
cancelled_at: Optional[int] = None
|
||||
failed_at: Optional[int] = None
|
||||
completed_at: Optional[int] = None
|
||||
last_error: Optional[dict] = None
|
||||
model: str
|
||||
instructions: str
|
||||
tools: list[dict]
|
||||
file_ids: list[str]
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
def process_run_in_background(
|
||||
message_id: int,
|
||||
parent_message_id: int,
|
||||
chat_session_id: UUID,
|
||||
assistant_id: int,
|
||||
instructions: str,
|
||||
tools: list[dict],
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
# Get the latest message in the chat session
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
search_tool_retrieval_details = RetrievalDetails()
|
||||
for tool in tools:
|
||||
if tool["type"] == SearchTool.__name__ and (
|
||||
retrieval_details := tool.get("retrieval_details")
|
||||
):
|
||||
search_tool_retrieval_details = RetrievalDetails.model_validate(
|
||||
retrieval_details
|
||||
)
|
||||
break
|
||||
|
||||
new_msg_req = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=int(parent_message_id) if parent_message_id else None,
|
||||
message=instructions,
|
||||
file_descriptors=[],
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=search_tool_retrieval_details, # Adjust as needed
|
||||
query_override=None,
|
||||
regenerate=None,
|
||||
llm_override=None,
|
||||
prompt_override=None,
|
||||
alternate_assistant_id=assistant_id,
|
||||
use_existing_user_message=True,
|
||||
existing_assistant_message_id=message_id,
|
||||
)
|
||||
|
||||
run_message = get_chat_message(message_id, user.id if user else None, db_session)
|
||||
try:
|
||||
for packet in stream_chat_message_objects(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
):
|
||||
if isinstance(packet, ChatMessageDetail):
|
||||
# Update the run status and message content
|
||||
run_message = get_chat_message(
|
||||
message_id, user.id if user else None, db_session
|
||||
)
|
||||
if run_message:
|
||||
# this handles cancelling
|
||||
if run_message.error:
|
||||
return
|
||||
|
||||
run_message.message = packet.message
|
||||
run_message.message_type = MessageType.ASSISTANT
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.exception("Error processing run in background")
|
||||
run_message.error = str(e)
|
||||
db_session.commit()
|
||||
return
|
||||
|
||||
db_session.refresh(run_message)
|
||||
if run_message.token_count == 0:
|
||||
run_message.error = "No tokens generated"
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/runs")
|
||||
def create_run(
|
||||
thread_id: UUID,
|
||||
run_request: RunRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RunResponse:
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=thread_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
latest_message = (
|
||||
chat_messages[-1]
|
||||
if chat_messages
|
||||
else get_or_create_root_message(chat_session.id, db_session)
|
||||
)
|
||||
|
||||
# Create a new "run" (chat message) in the session
|
||||
new_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=latest_message,
|
||||
message="",
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
token_count=0,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
db_session.flush()
|
||||
latest_message.latest_child_message = new_message.id
|
||||
db_session.commit()
|
||||
|
||||
# Schedule the background task
|
||||
background_tasks.add_task(
|
||||
process_run_in_background,
|
||||
new_message.id,
|
||||
latest_message.id,
|
||||
chat_session.id,
|
||||
run_request.assistant_id,
|
||||
run_request.instructions or "",
|
||||
run_request.tools or [],
|
||||
user,
|
||||
db_session,
|
||||
)
|
||||
|
||||
return RunResponse(
|
||||
id=str(new_message.id),
|
||||
object="thread.run",
|
||||
created_at=int(new_message.time_sent.timestamp()),
|
||||
assistant_id=run_request.assistant_id,
|
||||
thread_id=chat_session.id,
|
||||
status="queued",
|
||||
model=run_request.model or "default_model",
|
||||
instructions=run_request.instructions or "",
|
||||
tools=run_request.tools or [],
|
||||
file_ids=[],
|
||||
metadata=run_request.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}")
|
||||
def retrieve_run(
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RunResponse:
|
||||
# Retrieve the chat message (which represents a "run" in DAnswer)
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=int(run_id), # Convert string run_id to int
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
if not chat_message:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
chat_session = chat_message.chat_session
|
||||
|
||||
# Map DAnswer status to OpenAI status
|
||||
run_status: RunStatus = "queued"
|
||||
if chat_message.message:
|
||||
run_status = "in_progress"
|
||||
if chat_message.token_count != 0:
|
||||
run_status = "completed"
|
||||
if chat_message.error:
|
||||
run_status = "cancelled"
|
||||
|
||||
return RunResponse(
|
||||
id=run_id,
|
||||
object="thread.run",
|
||||
created_at=int(chat_message.time_sent.timestamp()),
|
||||
assistant_id=chat_session.persona_id or 0,
|
||||
thread_id=chat_session.id,
|
||||
status=run_status,
|
||||
started_at=int(chat_message.time_sent.timestamp()),
|
||||
completed_at=(
|
||||
int(chat_message.time_sent.timestamp()) if chat_message.message else None
|
||||
),
|
||||
model=chat_session.current_alternate_model or "default_model",
|
||||
instructions="", # DAnswer doesn't store per-message instructions
|
||||
tools=[], # DAnswer doesn't have a direct equivalent for tools
|
||||
file_ids=(
|
||||
[file["id"] for file in chat_message.files] if chat_message.files else []
|
||||
),
|
||||
metadata=None, # DAnswer doesn't store metadata for individual messages
|
||||
)
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
|
||||
def cancel_run(
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RunResponse:
|
||||
# In DAnswer, we don't have a direct equivalent to cancelling a run
|
||||
# We'll simulate it by marking the message as "cancelled"
|
||||
chat_message = (
|
||||
db_session.query(ChatMessage).filter(ChatMessage.id == run_id).first()
|
||||
)
|
||||
if not chat_message:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
chat_message.error = "Cancelled"
|
||||
db_session.commit()
|
||||
|
||||
return retrieve_run(thread_id, run_id, user, db_session)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs")
|
||||
def list_runs(
|
||||
thread_id: UUID,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[RunResponse]:
|
||||
# In DAnswer, we'll treat each message in a chat session as a "run"
|
||||
chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=thread_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Apply pagination
|
||||
if after:
|
||||
chat_messages = [msg for msg in chat_messages if str(msg.id) > after]
|
||||
if before:
|
||||
chat_messages = [msg for msg in chat_messages if str(msg.id) < before]
|
||||
|
||||
# Apply ordering
|
||||
chat_messages = sorted(
|
||||
chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc")
|
||||
)
|
||||
|
||||
# Apply limit
|
||||
chat_messages = chat_messages[:limit]
|
||||
|
||||
return [
|
||||
retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages
|
||||
]
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}/steps")
|
||||
def list_run_steps(
|
||||
run_id: str,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[dict]: # You may want to create a specific model for run steps
|
||||
# DAnswer doesn't have an equivalent to run steps
|
||||
# We'll return an empty list to maintain API compatibility
|
||||
return []
|
||||
|
||||
|
||||
# Additional helper functions can be added here if needed
|
||||
156
backend/danswer/server/openai_assistants_api/threads_api.py
Normal file
156
backend/danswer/server/openai_assistants_api/threads_api.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import delete_chat_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_chat_sessions_by_user
|
||||
from danswer.db.chat import update_chat_session
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.server.query_and_chat.models import ChatSessionDetails
|
||||
from danswer.server.query_and_chat.models import ChatSessionsResponse
|
||||
|
||||
router = APIRouter(prefix="/threads")
|
||||
|
||||
|
||||
# Models
|
||||
class Thread(BaseModel):
|
||||
id: UUID
|
||||
object: str = "thread"
|
||||
created_at: int
|
||||
metadata: Optional[dict[str, str]] = None
|
||||
|
||||
|
||||
class CreateThreadRequest(BaseModel):
|
||||
messages: Optional[list[dict]] = None
|
||||
metadata: Optional[dict[str, str]] = None
|
||||
|
||||
|
||||
class ModifyThreadRequest(BaseModel):
|
||||
metadata: Optional[dict[str, str]] = None
|
||||
|
||||
|
||||
# API Endpoints
|
||||
@router.post("")
|
||||
def create_thread(
|
||||
request: CreateThreadRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Thread:
|
||||
user_id = user.id if user else None
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="", # Leave the naming till later to prevent delay
|
||||
user_id=user_id,
|
||||
persona_id=0,
|
||||
)
|
||||
|
||||
return Thread(
|
||||
id=new_chat_session.id,
|
||||
created_at=int(new_chat_session.time_created.timestamp()),
|
||||
metadata=request.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}")
|
||||
def retrieve_thread(
|
||||
thread_id: UUID,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Thread:
|
||||
user_id = user.id if user else None
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=thread_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
return Thread(
|
||||
id=chat_session.id,
|
||||
created_at=int(chat_session.time_created.timestamp()),
|
||||
metadata=None, # Assuming we don't store metadata in our current implementation
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}")
|
||||
def modify_thread(
|
||||
thread_id: UUID,
|
||||
request: ModifyThreadRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Thread:
|
||||
user_id = user.id if user else None
|
||||
try:
|
||||
chat_session = update_chat_session(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
chat_session_id=thread_id,
|
||||
description=None, # Not updating description
|
||||
sharing_status=None, # Not updating sharing status
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
return Thread(
|
||||
id=chat_session.id,
|
||||
created_at=int(chat_session.time_created.timestamp()),
|
||||
metadata=request.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{thread_id}")
|
||||
def delete_thread(
|
||||
thread_id: UUID,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict:
|
||||
user_id = user.id if user else None
|
||||
try:
|
||||
delete_chat_session(
|
||||
user_id=user_id,
|
||||
chat_session_id=thread_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
return {"id": str(thread_id), "object": "thread.deleted", "deleted": True}
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_threads(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
user_id = user.id if user else None
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return ChatSessionsResponse(
|
||||
sessions=[
|
||||
ChatSessionDetails(
|
||||
id=chat.id,
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
]
|
||||
)
|
||||
@@ -347,7 +347,6 @@ def handle_new_chat_message(
|
||||
for packet in stream_chat_message(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
use_existing_user_message=chat_message_req.use_existing_user_message,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
|
||||
@@ -108,6 +108,9 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# used for seeded chats to kick off the generation of an AI answer
|
||||
use_existing_user_message: bool = False
|
||||
|
||||
# used for "OpenAI Assistants API"
|
||||
existing_assistant_message_id: int | None = None
|
||||
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
255
backend/danswer/tools/tool_constructor.py
Normal file
255
backend/danswer/tools/tool_constructor.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.headers import header_dict_to_header_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
"""Helper function to get image generation LLM config based on available providers"""
|
||||
if llm and llm.config.api_key and llm.config.model_provider == "openai":
|
||||
return LLMConfig(
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm.config.api_key,
|
||||
api_base=llm.config.api_base,
|
||||
api_version=llm.config.api_version,
|
||||
)
|
||||
|
||||
if llm.config.model_provider == "azure" and AZURE_DALLE_API_KEY is not None:
|
||||
return LLMConfig(
|
||||
model_provider="azure",
|
||||
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=AZURE_DALLE_API_KEY,
|
||||
api_base=AZURE_DALLE_API_BASE,
|
||||
api_version=AZURE_DALLE_API_VERSION,
|
||||
)
|
||||
|
||||
# Fallback to checking for OpenAI provider in database
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError("Image generation tool requires an OpenAI API key")
|
||||
|
||||
return LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
|
||||
|
||||
class SearchToolConfig(BaseModel):
|
||||
answer_style_config: AnswerStyleConfig = Field(
|
||||
default_factory=lambda: AnswerStyleConfig(citation_config=CitationConfig())
|
||||
)
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
selected_sections: list[InferenceSection] | None = None
|
||||
chunks_above: int = 0
|
||||
chunks_below: int = 0
|
||||
full_doc: bool = False
|
||||
latest_query_files: list[InMemoryChatFile] | None = None
|
||||
|
||||
|
||||
class InternetSearchToolConfig(BaseModel):
|
||||
answer_style_config: AnswerStyleConfig = Field(
|
||||
default_factory=lambda: AnswerStyleConfig(
|
||||
citation_config=CitationConfig(all_docs_useful=True)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationToolConfig(BaseModel):
|
||||
additional_headers: dict[str, str] | None = None
|
||||
|
||||
|
||||
class CustomToolConfig(BaseModel):
|
||||
chat_session_id: UUID | None = None
|
||||
message_id: int | None = None
|
||||
additional_headers: dict[str, str] | None = None
|
||||
|
||||
|
||||
def construct_tools(
|
||||
persona: Persona,
|
||||
prompt_config: PromptConfig,
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
fast_llm: LLM,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
internet_search_tool_config: InternetSearchToolConfig | None = None,
|
||||
image_generation_tool_config: ImageGenerationToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
) -> dict[int, list[Tool]]:
|
||||
"""Constructs tools based on persona configuration and available APIs"""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
for db_tool_model in persona.tools:
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
|
||||
# Handle Search Tool
|
||||
if tool_cls.__name__ == SearchTool.__name__:
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=search_tool_config.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
selected_sections=search_tool_config.selected_sections,
|
||||
chunks_above=search_tool_config.chunks_above,
|
||||
chunks_below=search_tool_config.chunks_below,
|
||||
full_doc=search_tool_config.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
|
||||
# Handle Image Generation Tool
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
if not image_generation_tool_config:
|
||||
image_generation_tool_config = ImageGenerationToolConfig()
|
||||
|
||||
img_generation_llm_config = _get_image_generation_config(
|
||||
llm, db_session
|
||||
)
|
||||
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=image_generation_tool_config.additional_headers,
|
||||
model=img_generation_llm_config.model_name,
|
||||
)
|
||||
]
|
||||
|
||||
# Handle Internet Search Tool
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
if not internet_search_tool_config:
|
||||
internet_search_tool_config = InternetSearchToolConfig()
|
||||
|
||||
if not BING_API_KEY:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(
|
||||
api_key=BING_API_KEY,
|
||||
answer_style_config=internet_search_tool_config.answer_style_config,
|
||||
prompt_config=prompt_config,
|
||||
)
|
||||
]
|
||||
|
||||
# Handle custom tools
|
||||
elif db_tool_model.openapi_schema:
|
||||
if not custom_tool_config:
|
||||
custom_tool_config = CustomToolConfig()
|
||||
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=custom_tool_config.chat_session_id,
|
||||
message_id=custom_tool_config.message_id,
|
||||
),
|
||||
custom_headers=(db_tool_model.custom_headers or [])
|
||||
+ (
|
||||
header_dict_to_header_list(
|
||||
custom_tool_config.additional_headers or {}
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
if search_tool_config:
|
||||
search_tool_config.document_pruning_config.tool_num_tokens = (
|
||||
compute_all_tool_tokens(
|
||||
tools,
|
||||
get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
),
|
||||
)
|
||||
)
|
||||
search_tool_config.document_pruning_config.using_tool_message = (
|
||||
explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
)
|
||||
|
||||
return tool_dict
|
||||
@@ -1,5 +1,6 @@
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
@@ -139,8 +140,19 @@ def fetch_ee_implementation_or_noop(
|
||||
Exception: If EE is enabled but the fetch fails.
|
||||
"""
|
||||
if not global_version.is_ee_version():
|
||||
return lambda *args, **kwargs: noop_return_value
|
||||
if inspect.iscoroutinefunction(noop_return_value):
|
||||
|
||||
async def async_noop(*args: Any, **kwargs: Any) -> Any:
|
||||
return await noop_return_value(*args, **kwargs)
|
||||
|
||||
return async_noop
|
||||
|
||||
else:
|
||||
|
||||
def sync_noop(*args: Any, **kwargs: Any) -> Any:
|
||||
return noop_return_value
|
||||
|
||||
return sync_noop
|
||||
try:
|
||||
return fetch_versioned_implementation(module, attribute)
|
||||
except Exception as e:
|
||||
|
||||
@@ -142,20 +142,6 @@ async def async_return_default_schema(*args: Any, **kwargs: Any) -> str:
|
||||
# Prefix used for all tenant ids
|
||||
TENANT_ID_PREFIX = "tenant_"
|
||||
|
||||
ALLOWED_SLACK_BOT_TENANT_IDS = os.environ.get("ALLOWED_SLACK_BOT_TENANT_IDS")
|
||||
DISALLOWED_SLACK_BOT_TENANT_LIST = (
|
||||
[tenant.strip() for tenant in ALLOWED_SLACK_BOT_TENANT_IDS.split(",")]
|
||||
if ALLOWED_SLACK_BOT_TENANT_IDS
|
||||
else None
|
||||
)
|
||||
|
||||
IGNORED_SYNCING_TENANT_IDS = os.environ.get("IGNORED_SYNCING_TENANT_ID")
|
||||
IGNORED_SYNCING_TENANT_LIST = (
|
||||
[tenant.strip() for tenant in IGNORED_SYNCING_TENANT_IDS.split(",")]
|
||||
if IGNORED_SYNCING_TENANT_IDS
|
||||
else None
|
||||
)
|
||||
|
||||
SUPPORTED_EMBEDDING_MODELS = [
|
||||
# Cloud-based models
|
||||
SupportedEmbeddingModel(
|
||||
|
||||
@@ -13,6 +13,14 @@ from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
DOMAIN = "test.com"
|
||||
DEFAULT_PASSWORD = "test"
|
||||
|
||||
|
||||
def build_email(name: str) -> str:
|
||||
return f"{name}@test.com"
|
||||
|
||||
|
||||
class UserManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
@@ -23,9 +31,9 @@ class UserManager:
|
||||
name = f"test{str(uuid4())}"
|
||||
|
||||
if email is None:
|
||||
email = f"{name}@test.com"
|
||||
email = build_email(name)
|
||||
|
||||
password = "test"
|
||||
password = DEFAULT_PASSWORD
|
||||
|
||||
body = {
|
||||
"email": email,
|
||||
|
||||
55
backend/tests/integration/openai_assistants_api/conftest.py
Normal file
55
backend/tests/integration/openai_assistants_api/conftest.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user() -> DATestUser | None:
|
||||
try:
|
||||
return UserManager.create("admin_user")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("admin_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
|
||||
return LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def thread_id(admin_user: Optional[DATestUser]) -> UUID:
|
||||
# Create a thread to use in the tests
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/threads", # Updated endpoint path
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return UUID(response.json()["id"])
|
||||
@@ -0,0 +1,151 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
ASSISTANTS_URL = f"{API_SERVER_URL}/openai-assistants/assistants"
|
||||
|
||||
|
||||
def test_create_assistant(admin_user: DATestUser | None) -> None:
|
||||
response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={
|
||||
"model": "gpt-3.5-turbo",
|
||||
"name": "Test Assistant",
|
||||
"description": "A test assistant",
|
||||
"instructions": "You are a helpful assistant.",
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Assistant"
|
||||
assert data["description"] == "A test assistant"
|
||||
assert data["model"] == "gpt-3.5-turbo"
|
||||
assert data["instructions"] == "You are a helpful assistant."
|
||||
|
||||
|
||||
def test_retrieve_assistant(admin_user: DATestUser | None) -> None:
|
||||
# First, create an assistant
|
||||
create_response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": "Retrieve Test"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Now, retrieve the assistant
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}/{assistant_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == assistant_id
|
||||
assert data["name"] == "Retrieve Test"
|
||||
|
||||
|
||||
def test_modify_assistant(admin_user: DATestUser | None) -> None:
|
||||
# First, create an assistant
|
||||
create_response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": "Modify Test"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Now, modify the assistant
|
||||
response = requests.post(
|
||||
f"{ASSISTANTS_URL}/{assistant_id}",
|
||||
json={"name": "Modified Assistant", "instructions": "New instructions"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == assistant_id
|
||||
assert data["name"] == "Modified Assistant"
|
||||
assert data["instructions"] == "New instructions"
|
||||
|
||||
|
||||
def test_delete_assistant(admin_user: DATestUser | None) -> None:
|
||||
# First, create an assistant
|
||||
create_response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": "Delete Test"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Now, delete the assistant
|
||||
response = requests.delete(
|
||||
f"{ASSISTANTS_URL}/{assistant_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == assistant_id
|
||||
assert data["deleted"] is True
|
||||
|
||||
|
||||
def test_list_assistants(admin_user: DATestUser | None) -> None:
|
||||
# Create multiple assistants
|
||||
for i in range(3):
|
||||
requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": f"List Test {i}"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the assistants
|
||||
response = requests.get(
|
||||
ASSISTANTS_URL,
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["object"] == "list"
|
||||
assert len(data["data"]) >= 3 # At least the 3 we just created
|
||||
assert all(assistant["object"] == "assistant" for assistant in data["data"])
|
||||
|
||||
|
||||
def test_list_assistants_pagination(admin_user: DATestUser | None) -> None:
|
||||
# Create 5 assistants
|
||||
for i in range(5):
|
||||
requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": f"Pagination Test {i}"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# List assistants with limit
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}?limit=2",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["data"]) == 2
|
||||
assert data["has_more"] is True
|
||||
|
||||
# Get next page
|
||||
before = data["last_id"]
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}?limit=2&before={before}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["data"]) == 2
|
||||
|
||||
|
||||
def test_assistant_not_found(admin_user: DATestUser | None) -> None:
|
||||
non_existent_id = -99
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}/{non_existent_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
133
backend/tests/integration/openai_assistants_api/test_messages.py
Normal file
133
backend/tests/integration/openai_assistants_api/test_messages.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants/threads"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def thread_id(admin_user: Optional[DATestUser]) -> str:
|
||||
response = requests.post(
|
||||
BASE_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
def test_create_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages", # URL structure matches API
|
||||
json={
|
||||
"role": "user",
|
||||
"content": "Hello, world!",
|
||||
"file_ids": [],
|
||||
"metadata": {"key": "value"},
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert "id" in response_json
|
||||
assert response_json["thread_id"] == thread_id
|
||||
assert response_json["role"] == "user"
|
||||
assert response_json["content"] == [{"type": "text", "text": "Hello, world!"}]
|
||||
assert response_json["metadata"] == {"key": "value"}
|
||||
|
||||
|
||||
def test_list_messages(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
# Create a message first
|
||||
requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the messages
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["object"] == "list"
|
||||
assert isinstance(response_json["data"], list)
|
||||
assert len(response_json["data"]) > 0
|
||||
assert "first_id" in response_json
|
||||
assert "last_id" in response_json
|
||||
assert "has_more" in response_json
|
||||
|
||||
|
||||
def test_retrieve_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
# Create a message first
|
||||
create_response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
message_id = create_response.json()["id"]
|
||||
|
||||
# Now, retrieve the message
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/{thread_id}/messages/{message_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == message_id
|
||||
assert response_json["thread_id"] == thread_id
|
||||
assert response_json["role"] == "user"
|
||||
assert response_json["content"] == [{"type": "text", "text": "Test message"}]
|
||||
|
||||
|
||||
def test_modify_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
# Create a message first
|
||||
create_response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
message_id = create_response.json()["id"]
|
||||
|
||||
# Now, modify the message
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages/{message_id}",
|
||||
json={"metadata": {"new_key": "new_value"}},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == message_id
|
||||
assert response_json["thread_id"] == thread_id
|
||||
assert response_json["metadata"] == {"new_key": "new_value"}
|
||||
|
||||
|
||||
def test_error_handling(admin_user: Optional[DATestUser]) -> None:
|
||||
non_existent_thread_id = str(uuid.uuid4())
|
||||
non_existent_message_id = -99
|
||||
|
||||
# Test with non-existent thread
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/{non_existent_thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
# Test with non-existent message
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/{non_existent_thread_id}/messages/{non_existent_message_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
137
backend/tests/integration/openai_assistants_api/test_runs.py
Normal file
137
backend/tests/integration/openai_assistants_api/test_runs.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def run_id(admin_user: DATestUser | None, thread_id: UUID) -> str:
|
||||
"""Create a run and return its ID."""
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
json={
|
||||
"assistant_id": 0,
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
def test_create_run(
|
||||
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
|
||||
) -> None:
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
json={
|
||||
"assistant_id": 0,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"instructions": "Test instructions",
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert "id" in response_json
|
||||
assert response_json["object"] == "thread.run"
|
||||
assert "created_at" in response_json
|
||||
assert response_json["assistant_id"] == 0
|
||||
assert UUID(response_json["thread_id"]) == thread_id
|
||||
assert response_json["status"] == "queued"
|
||||
assert response_json["model"] == "gpt-3.5-turbo"
|
||||
assert response_json["instructions"] == "Test instructions"
|
||||
|
||||
|
||||
def test_retrieve_run(
|
||||
admin_user: DATestUser | None,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
llm_provider: DATestLLMProvider,
|
||||
) -> None:
|
||||
retrieve_response = requests.get(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert retrieve_response.status_code == 200
|
||||
|
||||
response_json = retrieve_response.json()
|
||||
assert response_json["id"] == run_id
|
||||
assert response_json["object"] == "thread.run"
|
||||
assert "created_at" in response_json
|
||||
assert UUID(response_json["thread_id"]) == thread_id
|
||||
|
||||
|
||||
def test_cancel_run(
|
||||
admin_user: DATestUser | None,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
llm_provider: DATestLLMProvider,
|
||||
) -> None:
|
||||
cancel_response = requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/cancel",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert cancel_response.status_code == 200
|
||||
|
||||
response_json = cancel_response.json()
|
||||
assert response_json["id"] == run_id
|
||||
assert response_json["status"] == "cancelled"
|
||||
|
||||
|
||||
def test_list_runs(
|
||||
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
|
||||
) -> None:
|
||||
# Create a few runs
|
||||
for _ in range(3):
|
||||
requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
json={
|
||||
"assistant_id": 0,
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the runs
|
||||
list_response = requests.get(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert list_response.status_code == 200
|
||||
|
||||
response_json = list_response.json()
|
||||
assert isinstance(response_json, list)
|
||||
assert len(response_json) >= 3
|
||||
|
||||
for run in response_json:
|
||||
assert "id" in run
|
||||
assert run["object"] == "thread.run"
|
||||
assert "created_at" in run
|
||||
assert UUID(run["thread_id"]) == thread_id
|
||||
assert "status" in run
|
||||
assert "model" in run
|
||||
|
||||
|
||||
def test_list_run_steps(
|
||||
admin_user: DATestUser | None,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
llm_provider: DATestLLMProvider,
|
||||
) -> None:
|
||||
steps_response = requests.get(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/steps",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert steps_response.status_code == 200
|
||||
|
||||
response_json = steps_response.json()
|
||||
assert isinstance(response_json, list)
|
||||
# Since DAnswer doesn't have an equivalent to run steps, we expect an empty list
|
||||
assert len(response_json) == 0
|
||||
132
backend/tests/integration/openai_assistants_api/test_threads.py
Normal file
132
backend/tests/integration/openai_assistants_api/test_threads.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.db.models import ChatSessionSharedStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
THREADS_URL = f"{API_SERVER_URL}/openai-assistants/threads"
|
||||
|
||||
|
||||
def test_create_thread(admin_user: DATestUser | None) -> None:
|
||||
response = requests.post(
|
||||
THREADS_URL,
|
||||
json={"messages": None, "metadata": {"key": "value"}},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert "id" in response_json
|
||||
assert response_json["object"] == "thread"
|
||||
assert "created_at" in response_json
|
||||
assert response_json["metadata"] == {"key": "value"}
|
||||
|
||||
|
||||
def test_retrieve_thread(admin_user: DATestUser | None) -> None:
|
||||
# First, create a thread
|
||||
create_response = requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
thread_id = create_response.json()["id"]
|
||||
|
||||
# Now, retrieve the thread
|
||||
retrieve_response = requests.get(
|
||||
f"{THREADS_URL}/{thread_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert retrieve_response.status_code == 200
|
||||
|
||||
response_json = retrieve_response.json()
|
||||
assert response_json["id"] == thread_id
|
||||
assert response_json["object"] == "thread"
|
||||
assert "created_at" in response_json
|
||||
|
||||
|
||||
def test_modify_thread(admin_user: DATestUser | None) -> None:
|
||||
# First, create a thread
|
||||
create_response = requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
thread_id = create_response.json()["id"]
|
||||
|
||||
# Now, modify the thread
|
||||
modify_response = requests.post(
|
||||
f"{THREADS_URL}/{thread_id}",
|
||||
json={"metadata": {"new_key": "new_value"}},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert modify_response.status_code == 200
|
||||
|
||||
response_json = modify_response.json()
|
||||
assert response_json["id"] == thread_id
|
||||
assert response_json["metadata"] == {"new_key": "new_value"}
|
||||
|
||||
|
||||
def test_delete_thread(admin_user: DATestUser | None) -> None:
|
||||
# First, create a thread
|
||||
create_response = requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
thread_id = create_response.json()["id"]
|
||||
|
||||
# Now, delete the thread
|
||||
delete_response = requests.delete(
|
||||
f"{THREADS_URL}/{thread_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert delete_response.status_code == 200
|
||||
|
||||
response_json = delete_response.json()
|
||||
assert response_json["id"] == thread_id
|
||||
assert response_json["object"] == "thread.deleted"
|
||||
assert response_json["deleted"] is True
|
||||
|
||||
|
||||
def test_list_threads(admin_user: DATestUser | None) -> None:
|
||||
# Create a few threads
|
||||
for _ in range(3):
|
||||
requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the threads
|
||||
list_response = requests.get(
|
||||
THREADS_URL,
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert list_response.status_code == 200
|
||||
|
||||
response_json = list_response.json()
|
||||
assert "sessions" in response_json
|
||||
assert len(response_json["sessions"]) >= 3
|
||||
|
||||
for session in response_json["sessions"]:
|
||||
assert "id" in session
|
||||
assert "name" in session
|
||||
assert "persona_id" in session
|
||||
assert "time_created" in session
|
||||
assert "shared_status" in session
|
||||
assert "folder_id" in session
|
||||
assert "current_alternate_model" in session
|
||||
|
||||
# Validate UUID
|
||||
UUID(session["id"])
|
||||
|
||||
# Validate shared_status
|
||||
assert session["shared_status"] in [
|
||||
status.value for status in ChatSessionSharedStatus
|
||||
]
|
||||
10
ct.yaml
10
ct.yaml
@@ -1,12 +1,18 @@
|
||||
# See https://github.com/helm/chart-testing#configuration
|
||||
|
||||
# still have to specify this on the command line for list-changed
|
||||
chart-dirs:
|
||||
- deployment/helm/charts
|
||||
|
||||
# must be kept in sync with Chart.yaml
|
||||
chart-repos:
|
||||
- vespa=https://unoplat.github.io/vespa-helm-charts
|
||||
- vespa=https://danswer-ai.github.io/vespa-helm-charts
|
||||
- postgresql=https://charts.bitnami.com/bitnami
|
||||
|
||||
helm-extra-args: --timeout 600s
|
||||
helm-extra-args: --debug --timeout 600s
|
||||
|
||||
# nginx appears to not work on kind, likely due to lack of loadbalancer support
|
||||
# helm-extra-set-args also only works on the command line, not in this yaml
|
||||
# helm-extra-set-args: --set=nginx.enabled=false
|
||||
|
||||
validate-maintainers: false
|
||||
|
||||
@@ -9,11 +9,12 @@ spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-indexing
|
||||
minReplicaCount: 1
|
||||
maxReplicaCount: 30
|
||||
maxReplicaCount: 10
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing
|
||||
@@ -21,10 +22,10 @@ spec:
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing:2
|
||||
@@ -35,6 +36,7 @@ spec:
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing:3
|
||||
@@ -42,12 +44,3 @@ spec:
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: cpu
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
||||
|
||||
- type: memory
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
||||
|
||||
@@ -8,11 +8,12 @@ metadata:
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-light
|
||||
minReplicaCount: 5
|
||||
minReplicaCount: 1
|
||||
maxReplicaCount: 20
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync
|
||||
@@ -22,6 +23,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync:2
|
||||
@@ -31,6 +33,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync:3
|
||||
@@ -40,6 +43,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_deletion
|
||||
@@ -49,6 +53,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_deletion:2
|
||||
|
||||
@@ -15,6 +15,7 @@ spec:
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery
|
||||
@@ -25,6 +26,7 @@ spec:
|
||||
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:1
|
||||
@@ -34,6 +36,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:2
|
||||
@@ -43,6 +46,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:3
|
||||
@@ -52,6 +56,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: periodic_tasks
|
||||
@@ -61,6 +66,7 @@ spec:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: periodic_tasks:2
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: indexing-model-server-scaledobject
|
||||
namespace: danswer
|
||||
labels:
|
||||
app: indexing-model-server
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: indexing-model-server-deployment
|
||||
pollingInterval: 15 # Check every 15 seconds
|
||||
cooldownPeriod: 30 # Wait 30 seconds before scaling down
|
||||
minReplicaCount: 1
|
||||
maxReplicaCount: 14
|
||||
triggers:
|
||||
- type: cpu
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
||||
@@ -5,5 +5,5 @@ metadata:
|
||||
namespace: danswer
|
||||
type: Opaque
|
||||
data:
|
||||
host: { base64 encoded host here }
|
||||
password: { base64 encoded password here }
|
||||
host: { { base64-encoded-hostname } }
|
||||
password: { { base64-encoded-password } }
|
||||
|
||||
@@ -14,8 +14,8 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-beat
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
|
||||
@@ -14,8 +14,8 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-heavy
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
|
||||
@@ -14,8 +14,8 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-indexing
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
@@ -47,10 +47,10 @@ spec:
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "4Gi"
|
||||
memory: "1Gi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "8Gi"
|
||||
memory: "2Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
|
||||
@@ -14,8 +14,8 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-light
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
|
||||
@@ -14,8 +14,8 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-primary
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
|
||||
@@ -3,13 +3,13 @@ dependencies:
|
||||
repository: https://charts.bitnami.com/bitnami
|
||||
version: 14.3.1
|
||||
- name: vespa
|
||||
repository: https://unoplat.github.io/vespa-helm-charts
|
||||
version: 0.2.3
|
||||
repository: https://danswer-ai.github.io/vespa-helm-charts
|
||||
version: 0.2.16
|
||||
- name: nginx
|
||||
repository: oci://registry-1.docker.io/bitnamicharts
|
||||
version: 15.14.0
|
||||
- name: redis
|
||||
repository: https://charts.bitnami.com/bitnami
|
||||
version: 20.1.0
|
||||
digest: sha256:fb42426c1d13667a4929d0d6a7d681bf08120e4a4eb1d15437e4ec70920be3f8
|
||||
generated: "2024-09-11T09:16:03.312328-07:00"
|
||||
digest: sha256:711bbb76ba6ab604a36c9bf1839ab6faa5610afb21e535afd933c78f2d102232
|
||||
generated: "2024-11-07T09:39:30.17171-08:00"
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.danswer.ai/
|
||||
sources:
|
||||
- "https://github.com/danswer-ai/danswer"
|
||||
type: application
|
||||
version: 0.2.0
|
||||
version: 0.2.1
|
||||
appVersion: "latest"
|
||||
annotations:
|
||||
category: Productivity
|
||||
@@ -23,8 +23,8 @@ dependencies:
|
||||
repository: https://charts.bitnami.com/bitnami
|
||||
condition: postgresql.enabled
|
||||
- name: vespa
|
||||
version: 0.2.3
|
||||
repository: https://unoplat.github.io/vespa-helm-charts
|
||||
version: 0.2.16
|
||||
repository: https://danswer-ai.github.io/vespa-helm-charts
|
||||
condition: vespa.enabled
|
||||
- name: nginx
|
||||
version: 15.14.0
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
data:
|
||||
INTERNAL_URL: "http://{{ include "danswer-stack.fullname" . }}-api-service:{{ .Values.api.service.port | default 8080 }}"
|
||||
POSTGRES_HOST: {{ .Release.Name }}-postgresql
|
||||
VESPA_HOST: "document-index-service"
|
||||
VESPA_HOST: da-vespa-0.vespa-service
|
||||
REDIS_HOST: {{ .Release.Name }}-redis-master
|
||||
MODEL_SERVER_HOST: "{{ include "danswer-stack.fullname" . }}-inference-model-service"
|
||||
INDEXING_MODEL_SERVER_HOST: "{{ include "danswer-stack.fullname" . }}-indexing-model-service"
|
||||
|
||||
@@ -11,5 +11,5 @@ spec:
|
||||
- name: wget
|
||||
image: busybox
|
||||
command: ['wget']
|
||||
args: ['{{ include "danswer-stack.fullname" . }}:{{ .Values.webserver.service.port }}']
|
||||
args: ['{{ include "danswer-stack.fullname" . }}-webserver:{{ .Values.webserver.service.port }}']
|
||||
restartPolicy: Never
|
||||
|
||||
125
examples/assistants-api/topics_analyzer.py
Normal file
125
examples/assistants-api/topics_analyzer.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
ASSISTANT_NAME = "Topic Analyzer"
|
||||
SYSTEM_PROMPT = """
|
||||
You are a helpful assistant that analyzes topics by searching through available \
|
||||
documents and providing insights. These available documents come from common \
|
||||
workplace tools like Slack, emails, Confluence, Google Drive, etc.
|
||||
|
||||
When analyzing a topic:
|
||||
1. Search for relevant information using the search tool
|
||||
2. Synthesize the findings into clear insights
|
||||
3. Highlight key trends, patterns, or notable developments
|
||||
4. Maintain objectivity and cite sources where relevant
|
||||
"""
|
||||
USER_PROMPT = """
|
||||
Please analyze and provide insights about this topic: {topic}.
|
||||
|
||||
IMPORTANT: do not mention things that are not relevant to the specified topic. \
|
||||
If there is no relevant information, just say "No relevant information found."
|
||||
"""
|
||||
|
||||
|
||||
def wait_on_run(client: OpenAI, run, thread):
|
||||
while run.status == "queued" or run.status == "in_progress":
|
||||
run = client.beta.threads.runs.retrieve(
|
||||
thread_id=thread.id,
|
||||
run_id=run.id,
|
||||
)
|
||||
time.sleep(0.5)
|
||||
return run
|
||||
|
||||
|
||||
def show_response(messages) -> None:
|
||||
# Get only the assistant's response text
|
||||
for message in messages.data[::-1]:
|
||||
if message.role == "assistant":
|
||||
for content in message.content:
|
||||
if content.type == "text":
|
||||
print(content.text)
|
||||
break
|
||||
|
||||
|
||||
def analyze_topics(topics: list[str]) -> None:
|
||||
openai_api_key = os.environ.get(
|
||||
"OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"
|
||||
)
|
||||
danswer_api_key = os.environ.get(
|
||||
"DANSWER_API_KEY", "<your Danswer API key if not set as env var>"
|
||||
)
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url="http://localhost:8080/openai-assistants",
|
||||
default_headers={
|
||||
"Authorization": f"Bearer {danswer_api_key}",
|
||||
},
|
||||
)
|
||||
|
||||
# Create an assistant if it doesn't exist
|
||||
try:
|
||||
assistants = client.beta.assistants.list(limit=100)
|
||||
# Find the Topic Analyzer assistant if it exists
|
||||
assistant = next((a for a in assistants.data if a.name == ASSISTANT_NAME))
|
||||
client.beta.assistants.delete(assistant.id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
assistant = client.beta.assistants.create(
|
||||
name=ASSISTANT_NAME,
|
||||
instructions=SYSTEM_PROMPT,
|
||||
tools=[{"type": "SearchTool"}], # type: ignore
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
# Process each topic individually
|
||||
for topic in topics:
|
||||
thread = client.beta.threads.create()
|
||||
message = client.beta.threads.messages.create(
|
||||
thread_id=thread.id,
|
||||
role="user",
|
||||
content=USER_PROMPT.format(topic=topic),
|
||||
)
|
||||
|
||||
run = client.beta.threads.runs.create(
|
||||
thread_id=thread.id,
|
||||
assistant_id=assistant.id,
|
||||
tools=[
|
||||
{ # type: ignore
|
||||
"type": "SearchTool",
|
||||
"retrieval_details": {
|
||||
"run_search": "always",
|
||||
"filters": {
|
||||
"time_cutoff": str(
|
||||
datetime.now(timezone.utc) - timedelta(days=7)
|
||||
)
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
run = wait_on_run(client, run, thread)
|
||||
messages = client.beta.threads.messages.list(
|
||||
thread_id=thread.id, order="asc", after=message.id
|
||||
)
|
||||
print(f"\nAnalysis for topic: {topic}")
|
||||
print("-" * 40)
|
||||
show_response(messages)
|
||||
print()
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Analyze specific topics")
|
||||
parser.add_argument("topics", nargs="+", help="Topics to analyze (one or more)")
|
||||
|
||||
args = parser.parse_args()
|
||||
analyze_topics(args.topics)
|
||||
@@ -194,7 +194,9 @@ function ConnectorRow({
|
||||
return (
|
||||
<TableRow
|
||||
className={`hover:bg-hover-light ${
|
||||
invisible ? "invisible !h-0 !-mb-10" : "!border !border-border"
|
||||
invisible
|
||||
? "invisible !h-0 !-mb-10 !border-none"
|
||||
: "!border !border-border"
|
||||
} w-full cursor-pointer relative `}
|
||||
onClick={() => {
|
||||
router.push(`/admin/connector/${ccPairsIndexingStatus.cc_pair_id}`);
|
||||
@@ -434,7 +436,6 @@ export function CCPairIndexingStatusTable({
|
||||
{!shouldExpand ? "Collapse All" : "Expand All"}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<TableBody>
|
||||
{sortedSources
|
||||
.filter(
|
||||
@@ -454,14 +455,12 @@ export function CCPairIndexingStatusTable({
|
||||
return (
|
||||
<React.Fragment key={ind}>
|
||||
<br className="mt-4" />
|
||||
|
||||
<SummaryRow
|
||||
source={source}
|
||||
summary={groupSummaries[source]}
|
||||
isOpen={connectorsToggled[source] || false}
|
||||
onToggle={() => toggleSource(source)}
|
||||
/>
|
||||
|
||||
{connectorsToggled[source] && (
|
||||
<>
|
||||
<TableRow className="border border-border">
|
||||
|
||||
@@ -188,7 +188,7 @@ const AddUserButton = ({
|
||||
};
|
||||
return (
|
||||
<>
|
||||
<Button className="w-fit" onClick={() => setModal(true)}>
|
||||
<Button className="my-auto w-fit" onClick={() => setModal(true)}>
|
||||
<div className="flex">
|
||||
<FiPlusSquare className="my-auto mr-2" />
|
||||
Invite Users
|
||||
|
||||
@@ -4,7 +4,6 @@ import { fetchChatData } from "@/lib/chat/fetchChatData";
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { redirect } from "next/navigation";
|
||||
import WrappedAssistantsGallery from "./WrappedAssistantsGallery";
|
||||
import { AssistantsProvider } from "@/components/context/AssistantsContext";
|
||||
import { cookies } from "next/headers";
|
||||
|
||||
export default async function GalleryPage(props: {
|
||||
|
||||
@@ -257,11 +257,8 @@ export function ChatPage({
|
||||
|
||||
const noAssistants = liveAssistant == null || liveAssistant == undefined;
|
||||
|
||||
// always set the model override for the chat session, when an assistant, llm provider, or user preference exists
|
||||
useEffect(() => {
|
||||
if (!loadedIdSessionRef.current && !currentPersonaId) {
|
||||
return;
|
||||
}
|
||||
|
||||
const personaDefault = getLLMProviderOverrideForPersona(
|
||||
liveAssistant,
|
||||
llmProviders
|
||||
|
||||
@@ -420,7 +420,7 @@ export function ClientLayout({
|
||||
<div className="fixed bg-background left-0 gap-x-4 mb-8 px-4 py-2 w-full items-center flex justify-end">
|
||||
<UserDropdown />
|
||||
</div>
|
||||
<div className="pt-20 flex overflow-y-auto h-full px-4 md:px-12">
|
||||
<div className="pt-20 flex overflow-y-auto overflow-x-hidden h-full px-4 md:px-12">
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -12,7 +12,7 @@ const buttonVariants = cva(
|
||||
success:
|
||||
"bg-green-100 text-green-600 hover:bg-green-500/90 dark:bg-blue-500 dark:text-neutral-50 dark:hover:bg-green-900/90",
|
||||
"success-reverse":
|
||||
"bg-green-500 text-white hover:bg-green-600/90 dark:bg-neutral-50 dark:text-blue-500 dark:hover:bg-green-100/90",
|
||||
"bg-green-500 text-inverted hover:bg-green-600/90 dark:bg-neutral-50 dark:text-blue-500 dark:hover:bg-green-100/90",
|
||||
|
||||
default:
|
||||
"bg-neutral-900 border-border text-neutral-50 hover:bg-neutral-900/90 dark:bg-neutral-50 dark:text-neutral-900 dark:hover:bg-neutral-50/90",
|
||||
@@ -38,7 +38,7 @@ const buttonVariants = cva(
|
||||
"link-reverse":
|
||||
"text-neutral-50 underline-offset-4 hover:underline dark:text-neutral-900",
|
||||
submit:
|
||||
"bg-green-500 text-green-100 hover:bg-green-600/90 dark:bg-neutral-50 dark:text-blue-500 dark:hover:bg-green-100/90",
|
||||
"bg-green-500 text-inverted hover:bg-green-600/90 dark:bg-neutral-50 dark:text-blue-500 dark:hover:bg-green-100/90",
|
||||
|
||||
// "bg-blue-600 text-neutral-50 hover:bg-blue-600/80 dark:bg-blue-600 dark:text-neutral-50 dark:hover:bg-blue-600/90",
|
||||
"submit-reverse":
|
||||
|
||||
Reference in New Issue
Block a user