Compare commits

...

6 Commits

Author SHA1 Message Date
Chris Weaver
73af85b524 Add model updates (#3736) 2025-01-21 14:43:39 -08:00

* Add model updates

* Support specific format + add ut

* Small fix

* Actually add test

* Also update display models
pablonyx
151c6b526a Allow all LLMs for image generation assistants (#3732) 2025-01-21 13:27:35 -08:00
pablonyx
e5770e35d8 Update llm override hook (#3691) 2025-01-16 17:37:44 -08:00
Weves
fd25753013 Don't pass in tools for non-openai models 2025-01-13 13:03:42 -08:00
Weves
88aad7f411 bump litellm + openai version 2025-01-10 17:33:39 -08:00
Weves
5ca4efc24f Add tests for some LLM provider endpoints + small logic change to ensure that display_model_names is not empty 2025-01-10 11:15:51 -08:00
25 changed files with 812 additions and 179 deletions

View File

@@ -265,5 +265,6 @@ celery_app.autodiscover_tasks(
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
"danswer.background.celery.tasks.llm_model_update",
]
)

View File

@@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Any
from danswer.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from danswer.configs.constants import DanswerCeleryPriority
@@ -55,6 +56,19 @@ tasks_to_schedule = [
},
]
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
{
"name": "check-for-llm-model-update",
"task": "check_for_llm_model_update",
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": DanswerCeleryPriority.LOW,
},
}
)
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return tasks_to_schedule
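
For context, a minimal sketch of enabling this schedule (the endpoint URL below is a placeholder, not taken from this diff). LLM_MODEL_UPDATE_API_URL is read at import time, so it must be set before the Celery beat process starts:

# Hypothetical deployment configuration (illustrative values only):
#   export LLM_MODEL_UPDATE_API_URL="https://llm-gateway.example.com/v1/models"
# With the variable set, beat runs check_for_llm_model_update every hour at
# LOW priority; with it unset, the task is never added to tasks_to_schedule.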

View File

@@ -0,0 +1,120 @@
from typing import Any
import requests
from celery import shared_task
from celery import Task
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import LLMProvider
def _process_model_list_response(model_list_json: Any) -> list[str]:
# Handle case where the response is wrapped in a "data" or "models" field
if isinstance(model_list_json, dict):
if "data" in model_list_json:
model_list_json = model_list_json["data"]
elif "models" in model_list_json:
model_list_json = model_list_json["models"]
else:
raise ValueError(
"Invalid response from API - expected dict with 'data' or "
f"'models' field, got {type(model_list_json)}"
)
if not isinstance(model_list_json, list):
raise ValueError(
f"Invalid response from API - expected list, got {type(model_list_json)}"
)
# Handle both string list and object list cases
model_names: list[str] = []
for item in model_list_json:
if isinstance(item, str):
model_names.append(item)
elif isinstance(item, dict):
if "model_name" in item:
model_names.append(item["model_name"])
elif "id" in item:
model_names.append(item["id"])
else:
raise ValueError(
f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
)
else:
raise ValueError(
f"Invalid item in model list - expected string or dict, got {type(item)}"
)
return model_names
@shared_task(
name="check_for_llm_model_update",
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
if not LLM_MODEL_UPDATE_API_URL:
raise ValueError("LLM model update API URL not configured")
# First fetch the models from the API
try:
response = requests.get(LLM_MODEL_UPDATE_API_URL)
response.raise_for_status()
available_models = _process_model_list_response(response.json())
task_logger.info(f"Found available models: {available_models}")
except Exception:
task_logger.exception("Failed to fetch models from API.")
return None
# Then update the database with the fetched models
with get_session_with_tenant(tenant_id) as db_session:
# Get the default LLM provider
default_provider = (
db_session.query(LLMProvider)
.filter(LLMProvider.is_default_provider.is_(True))
.first()
)
if not default_provider:
task_logger.warning("No default LLM provider found")
return None
# log change if any
old_models = set(default_provider.model_names or [])
new_models = set(available_models)
added_models = new_models - old_models
removed_models = old_models - new_models
if added_models:
task_logger.info(f"Adding models: {sorted(added_models)}")
if removed_models:
task_logger.info(f"Removing models: {sorted(removed_models)}")
# Update the provider's model list
default_provider.model_names = available_models
default_provider.display_model_names = available_models
# if the default model is no longer available, set it to the first model in the list
if default_provider.default_model_name not in available_models:
task_logger.info(
f"Default model {default_provider.default_model_name} not "
f"available, setting to first model in list: {available_models[0]}"
)
default_provider.default_model_name = available_models[0]
if default_provider.fast_default_model_name not in available_models:
task_logger.info(
f"Fast default model {default_provider.fast_default_model_name} "
f"not available, setting to first model in list: {available_models[0]}"
)
default_provider.fast_default_model_name = available_models[0]
db_session.commit()
if added_models or removed_models:
task_logger.info("Updated model list for default provider.")
return True
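
As a quick reference, a sketch of the response shapes _process_model_list_response accepts (payloads are illustrative, not from a real endpoint):

# All three calls below return ["gpt-4", "gpt-4o"]:
_process_model_list_response(["gpt-4", "gpt-4o"])  # bare list of names
_process_model_list_response({"data": [{"id": "gpt-4"}, {"id": "gpt-4o"}]})  # OpenAI-style wrapper
_process_model_list_response({"models": [{"model_name": "gpt-4"}, {"model_name": "gpt-4o"}]})
# Anything else (a bare string, a dict without "data"/"models", or list items
# missing both "model_name" and "id") raises ValueError, as the unit tests
# later in this diff verify.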

View File

@@ -468,6 +468,8 @@ AZURE_DALLE_API_KEY = os.environ.get("AZURE_DALLE_API_KEY")
AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE")
AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")
# LLM Model Update API endpoint
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"

View File

@@ -234,7 +234,12 @@ class Answer:
# DEBUG: good breakpoint
stream = self.llm.stream(
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tools=(
[tool.tool_definition() for tool in current_llm_call.tools]
if self.using_tool_calling_llm
else None
)
or None,
tool_choice=(
"required"
if current_llm_call.tools and current_llm_call.force_use_tool.force_use

View File

@@ -13,6 +13,7 @@ from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.llm.utils import model_supports_image_input
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
@@ -67,6 +68,7 @@ class AnswerPromptBuilder:
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
)
self.llm_config = llm_config
self.llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
@@ -75,7 +77,13 @@ class AnswerPromptBuilder:
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(message_history)
) = translate_history_to_basemessages(
message_history,
exclude_images=not model_supports_image_input(
self.llm_config.model_name,
self.llm_config.model_provider,
),
)
# for cases like the QA flow where we want to condense the chat history
# into a single message rather than a sequence of User / Assistant messages
@@ -84,7 +92,10 @@ class AnswerPromptBuilder:
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
check_message_tokens(
user_message,
self.llm_tokenizer_encode_func,
),
)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []

View File

@@ -1,3 +1,4 @@
import copy
import io
import json
from collections.abc import Callable
@@ -105,6 +106,7 @@ def litellm_exception_to_error_msg(
def translate_danswer_msg_to_langchain(
msg: Union[ChatMessage, "PreviousMessage"],
exclude_images: bool = False,
) -> BaseMessage:
files: list[InMemoryChatFile] = []
@@ -112,7 +114,9 @@ def translate_danswer_msg_to_langchain(
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
content = build_content_with_imgs(
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
)
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
@@ -125,10 +129,10 @@ def translate_danswer_msg_to_langchain(
def translate_history_to_basemessages(
history: list[ChatMessage] | list["PreviousMessage"],
history: list[ChatMessage] | list["PreviousMessage"], exclude_images: bool = False
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_danswer_msg_to_langchain(msg)
translate_danswer_msg_to_langchain(msg, exclude_images)
for msg in history
if msg.token_count != 0
]
@@ -188,6 +192,7 @@ def build_content_with_imgs(
files: list[InMemoryChatFile] | None = None,
img_urls: list[str] | None = None,
message_type: MessageType = MessageType.USER,
exclude_images: bool = False,
) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type
files = files or []
@@ -202,7 +207,7 @@ def build_content_with_imgs(
message_main_content = _build_content(message, files)
if not img_files and not img_urls:
if exclude_images or (not img_files and not img_urls):
return message_main_content
return cast(
@@ -383,6 +388,72 @@ def test_llm(llm: LLM) -> str | None:
return error_msg
def get_model_map() -> dict:
starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
# NOTE: we could add additional models here in the future,
# but for now there is no point. Ollama allows the user to
# specify their desired max context window, and it's
# unlikely to be standard across users even for the same model
# (it heavily depends on their hardware). For now, we'll just
# rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
# for model_name in [
# "llama3.2",
# "llama3.2:1b",
# "llama3.2:3b",
# "llama3.2:11b",
# "llama3.2:90b",
# ]:
# starting_map[f"ollama/{model_name}"] = {
# "max_tokens": 128000,
# "max_input_tokens": 128000,
# "max_output_tokens": 128000,
# }
return starting_map
def _strip_extra_provider_from_model_name(model_name: str) -> str:
return model_name.split("/")[1] if "/" in model_name else model_name
def _strip_colon_from_model_name(model_name: str) -> str:
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | None:
stripped_model_name = _strip_extra_provider_from_model_name(model_name)
model_names = [
model_name,
_strip_extra_provider_from_model_name(model_name),
# Remove leading extra provider. Usually for cases where the user has a
# custom model proxy which prepends another prefix
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(stripped_model_name),
]
# Filter out empty model names
filtered_model_names = [name for name in model_names if name]
# First try all model names with provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(f"{provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {provider}/{model_name}")
return model_obj
# Then try all model names without provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
return model_obj
return None
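
To make the fallback order concrete, a sketch of the candidates _find_model_obj tries for a hypothetical proxied Ollama model (the model name is assumed for illustration):

# provider="ollama", model_name="my-proxy/llama3.2:3b" yields the candidates:
#   "my-proxy/llama3.2:3b"  (as given)
#   "llama3.2:3b"           (extra provider prefix stripped)
#   "my-proxy/llama3.2"     (":3b" tag stripped)
#   "llama3.2"              (both stripped)
# Each candidate is first looked up as "ollama/<candidate>", then bare.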
def get_llm_max_tokens(
model_map: dict,
model_name: str,
@@ -395,22 +466,11 @@ def get_llm_max_tokens(
return GEN_AI_MAX_TOKENS
try:
model_obj = model_map.get(f"{model_provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {model_provider}/{model_name}")
if not model_obj:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
if not model_obj:
model_name_split = model_name.split("/")
if len(model_name_split) > 1:
model_obj = model_map.get(model_name_split[1])
if model_obj:
logger.debug(f"Using model object for {model_name_split[1]}")
model_obj = _find_model_obj(
model_map,
model_provider,
model_name,
)
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
@@ -501,3 +561,23 @@ def get_max_input_tokens(
raise RuntimeError("No tokens for input for the LLM given settings")
return input_toks
def model_supports_image_input(model_name: str, model_provider: str) -> bool:
model_map = get_model_map()
try:
model_obj = _find_model_obj(
model_map,
model_provider,
model_name,
)
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
)
return model_obj.get("supports_vision", False)
except Exception:
logger.exception(
f"Failed to get model object for {model_provider}/{model_name}"
)
return False
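
Downstream callers use this helper to gate image handling; a usage sketch (model names assumed for illustration):

# True only if litellm's model map marks the entry with supports_vision;
# any lookup failure is logged and conservatively treated as no support.
if not model_supports_image_input("gpt-4o", "openai"):
    # e.g. AnswerPromptBuilder passes exclude_images=True so image parts
    # are dropped from the translated chat history
    ...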

View File

@@ -144,19 +144,20 @@ def put_llm_provider(
detail=f"LLM Provider with name {llm_provider.name} already exists",
)
# Ensure default_model_name and fast_default_model_name are in display_model_names
# This is necessary for custom models and Bedrock/Azure models
if llm_provider.display_model_names is None:
llm_provider.display_model_names = []
if llm_provider.display_model_names is not None:
# Ensure default_model_name and fast_default_model_name are in display_model_names
# This is necessary for custom models and Bedrock/Azure models
if llm_provider.default_model_name not in llm_provider.display_model_names:
llm_provider.display_model_names.append(llm_provider.default_model_name)
if llm_provider.default_model_name not in llm_provider.display_model_names:
llm_provider.display_model_names.append(llm_provider.default_model_name)
if (
llm_provider.fast_default_model_name
and llm_provider.fast_default_model_name not in llm_provider.display_model_names
):
llm_provider.display_model_names.append(llm_provider.fast_default_model_name)
if (
llm_provider.fast_default_model_name
and llm_provider.fast_default_model_name
not in llm_provider.display_model_names
):
llm_provider.display_model_names.append(
llm_provider.fast_default_model_name
)
try:
return upsert_llm_provider(
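
A worked example of the changed behavior (request values assumed for illustration):

# Hypothetical upsert payload missing both defaults from the display list:
#   display_model_names=["gpt-4o"], default_model_name="gpt-4",
#   fast_default_model_name="gpt-4o-mini"
# -> stored display_model_names == ["gpt-4o", "gpt-4", "gpt-4o-mini"]
# If display_model_names is omitted (None), it now stays None instead of
# being initialized to an empty list and padded with the defaults.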

View File

@@ -14,6 +14,7 @@ from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import message_to_string
from danswer.llm.utils import model_supports_image_input
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
@@ -293,6 +294,11 @@ class ImageGenerationTool(Tool):
build_image_generation_user_prompt(
query=prompt_builder.get_user_message_content(),
img_urls=img_urls,
prompts=[img.revised_prompt for img in img_generation_response],
supports_image_input=model_supports_image_input(
prompt_builder.llm_config.model_name,
prompt_builder.llm_config.model_provider,
),
)
)

View File

@@ -6,16 +6,38 @@ from danswer.llm.utils import build_content_with_imgs
IMG_GENERATION_SUMMARY_PROMPT = """
You have just created the attached images in response to the following query: "{query}".
{img_urls}
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
"""
IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES = """
You have generated images based on the following query: "{query}".
The prompts used to generate these images were: {prompts}
Describe what the generated images depict based on the query and prompts provided.
Summarize the key elements and content of the images in a sentence or two. Be specific
about what was generated rather than speculating about what the images 'likely' contain.
"""
def build_image_generation_user_prompt(
query: str, img_urls: list[str] | None = None
query: str,
supports_image_input: bool,
img_urls: list[str] | None = None,
prompts: list[str] | None = None,
) -> HumanMessage:
return HumanMessage(
content=build_content_with_imgs(
message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
img_urls=img_urls,
if supports_image_input:
return HumanMessage(
content=build_content_with_imgs(
message=IMG_GENERATION_SUMMARY_PROMPT.format(
query=query, img_urls=img_urls
).strip(),
)
)
else:
return HumanMessage(
content=IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES.format(
query=query, prompts=prompts
).strip()
)
)

View File

@@ -29,7 +29,7 @@ trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.50.2
litellm==1.55.4
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45
@@ -38,7 +38,7 @@ msal==1.28.0
nltk==3.8.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.52.2
openai==1.55.3
openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5

View File

@@ -4,8 +4,12 @@ from collections.abc import Generator
import pytest
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.reset import reset_all_multitenant
@@ -57,6 +61,30 @@ def new_admin_user(reset: None) -> DATestUser | None:
return None
@pytest.fixture
def admin_user() -> DATestUser | None:
try:
return UserManager.create(name="admin_user")
except Exception:
pass
try:
return UserManager.login_as_user(
DATestUser(
id="",
email=build_email("admin_user"),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
role=UserRole.ADMIN,
is_active=True,
)
)
except Exception:
pass
return None
@pytest.fixture
def reset_multitenant() -> None:
reset_all_multitenant()

View File

@@ -7,37 +7,12 @@ import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
@pytest.fixture
def admin_user() -> DATestUser | None:
try:
return UserManager.create("admin_user")
except Exception:
pass
try:
return UserManager.login_as_user(
DATestUser(
id="",
email=build_email("admin_user"),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
)
)
except Exception:
pass
return None
@pytest.fixture
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
return LLMProviderManager.create(user_performing_action=admin_user)

View File

@@ -0,0 +1,120 @@
import uuid
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.test_models import DATestUser
_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]
def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
"""Utility function to fetch an LLM provider by ID"""
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
)
assert response.status_code == 200
providers = response.json()
return next((p for p in providers if p["id"] == provider_id), None)
def test_create_llm_provider_without_display_model_names(
admin_user: DATestUser,
) -> None:
"""Test creating an LLM provider without specifying
display_model_names and verify it's null in response"""
# Create LLM provider without display_model_names
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": str(uuid.uuid4()),
"provider": "openai",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200
created_provider = response.json()
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
# Verify display_model_names is None/null
assert provider_data is not None
assert provider_data["model_names"] == _DEFAULT_MODELS
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
assert provider_data["display_model_names"] is None
def test_update_llm_provider_model_names(admin_user: DATestUser) -> None:
"""Test updating an LLM provider's model_names"""
# First create provider with only a single model name
name = str(uuid.uuid4())
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": name,
"provider": "openai",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": [_DEFAULT_MODELS[0]],
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200
created_provider = response.json()
# Update with model_names
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"id": created_provider["id"],
"name": name,
"provider": created_provider["provider"],
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200
# Verify update
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is not None
assert provider_data["model_names"] == _DEFAULT_MODELS
def test_delete_llm_provider(admin_user: DATestUser) -> None:
"""Test deleting an LLM provider"""
# Create a provider
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": "test-provider-delete",
"provider": "openai",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
},
)
assert response.status_code == 200
created_provider = response.json()
# Delete the provider
response = requests.delete(
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
headers=admin_user.headers,
)
assert response.status_code == 200
# Verify provider is deleted by checking it's not in the list
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is None

View File

@@ -0,0 +1,118 @@
import json
import os
from tempfile import NamedTemporaryFile
import requests
from danswer.db.engine import get_session_context_manager
from danswer.db.models import Tool
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.settings.models import Settings
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
from ee.danswer.server.seeding import SeedConfiguration
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
def test_seeding(reset: None) -> None:
# Create admin user
admin_user: DATestUser = UserManager.create(name="admin_user")
# Create temporary files for testing
with (
NamedTemporaryFile(mode="w", suffix=".json") as tool_file,
NamedTemporaryFile(mode="w", suffix=".svg") as logo_file,
):
# Write test tool definition
tool_definition = {
"openapi": "3.0.0",
"info": {"title": "Test Tool", "version": "1.0.0"},
"paths": {},
}
json.dump(tool_definition, tool_file)
tool_file.flush()
# Write test logo
logo_file.write("<svg>Test Logo</svg>")
logo_file.flush()
# Create seed configuration
seed_config = SeedConfiguration(
llms=[
LLMProviderUpsertRequest(
model_name="test-model",
model_provider="test-provider",
api_key="test-key",
)
],
personas=[
CreatePersonaRequest(
name="Test Persona",
description="A test persona",
num_chunks=5,
)
],
settings=Settings(
enable_experimental_features=True,
),
enterprise_settings=EnterpriseSettings(
disable_source_filters=True,
),
seeded_logo_path=logo_file.name,
custom_tools=[
{
"name": "test-tool",
"description": "A test tool",
"definition_path": tool_file.name,
}
],
)
# Set environment variable with seed configuration
os.environ["ENV_SEED_CONFIGURATION"] = seed_config.model_dump_json()
# Verify seeded LLM
response = requests.get(
f"{API_SERVER_URL}/manage/llm-providers",
headers=admin_user.headers,
)
assert response.status_code == 200
llms = response.json()
assert any(llm["model_name"] == "test-model" for llm in llms)
# Verify seeded persona
response = requests.get(
f"{API_SERVER_URL}/personas",
headers=admin_user.headers,
)
assert response.status_code == 200
personas = response.json()
assert any(persona["name"] == "Test Persona" for persona in personas)
# Verify seeded tool
with get_session_context_manager() as db_session:
tools = db_session.query(Tool).all()
assert any(tool.name == "test-tool" for tool in tools)
# Verify settings
response = requests.get(
f"{API_SERVER_URL}/settings",
headers=admin_user.headers,
)
assert response.status_code == 200
settings = response.json()
assert settings["enable_experimental_features"] is True
# Verify enterprise settings
response = requests.get(
f"{API_SERVER_URL}/enterprise-settings",
headers=admin_user.headers,
)
assert response.status_code == 200
ee_settings = response.json()
assert ee_settings["disable_source_filters"] is True
# Clean up
os.environ.pop("ENV_SEED_CONFIGURATION", None)

View File

@@ -0,0 +1,92 @@
import pytest
from danswer.background.celery.tasks.llm_model_update.tasks import (
_process_model_list_response,
)
@pytest.mark.parametrize(
"input_data,expected_result,expected_error,error_match",
[
# Success cases
(
["gpt-4", "gpt-3.5-turbo", "claude-2"],
["gpt-4", "gpt-3.5-turbo", "claude-2"],
None,
None,
),
(
[
{"model_name": "gpt-4", "other_field": "value"},
{"model_name": "gpt-3.5-turbo", "other_field": "value"},
],
["gpt-4", "gpt-3.5-turbo"],
None,
None,
),
(
[
{"id": "gpt-4", "other_field": "value"},
{"id": "gpt-3.5-turbo", "other_field": "value"},
],
["gpt-4", "gpt-3.5-turbo"],
None,
None,
),
(
{"data": ["gpt-4", "gpt-3.5-turbo"]},
["gpt-4", "gpt-3.5-turbo"],
None,
None,
),
(
{"models": ["gpt-4", "gpt-3.5-turbo"]},
["gpt-4", "gpt-3.5-turbo"],
None,
None,
),
(
{"models": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}]},
["gpt-4", "gpt-3.5-turbo"],
None,
None,
),
# Error cases
(
"not a list",
None,
ValueError,
"Invalid response from API - expected list",
),
(
{"wrong_field": []},
None,
ValueError,
"Invalid response from API - expected dict with 'data' or 'models' field",
),
(
[{"wrong_field": "value"}],
None,
ValueError,
"Invalid item in model list - expected dict with model_name or id",
),
(
[42],
None,
ValueError,
"Invalid item in model list - expected string or dict",
),
],
)
def test_process_model_list_response(
input_data: dict | list,
expected_result: list[str] | None,
expected_error: type[Exception] | None,
error_match: str | None,
) -> None:
if expected_error:
with pytest.raises(expected_error, match=error_match):
_process_model_list_response(input_data)
else:
result = _process_model_list_response(input_data)
assert result == expected_result

View File

@@ -315,26 +315,10 @@ export function AssistantEditor({
let enabledTools = Object.keys(values.enabled_tools_map)
.map((toolId) => Number(toolId))
.filter((toolId) => values.enabled_tools_map[toolId]);
const searchToolEnabled = searchTool
? enabledTools.includes(searchTool.id)
: false;
const imageGenerationToolEnabled = imageGenerationTool
? enabledTools.includes(imageGenerationTool.id)
: false;
if (imageGenerationToolEnabled) {
if (
// model must support image input for image generation
// to work
!checkLLMSupportsImageInput(
values.llm_model_version_override || defaultModelName || ""
)
) {
enabledTools = enabledTools.filter(
(toolId) => toolId !== imageGenerationTool!.id
);
}
}
// if disable_retrieval is set, set num_chunks to 0
// to tell the backend to not fetch any documents
@@ -743,7 +727,6 @@ export function AssistantEditor({
<TooltipTrigger asChild>
<div
className={`w-fit ${
!currentLLMSupportsImageOutput ||
!isImageGenerationAvailable
? "opacity-70 cursor-not-allowed"
: ""
@@ -756,30 +739,17 @@ export function AssistantEditor({
onChange={() => {
toggleToolInValues(imageGenerationTool.id);
}}
disabled={
!currentLLMSupportsImageOutput ||
!isImageGenerationAvailable
}
disabled={!isImageGenerationAvailable}
/>
</div>
</TooltipTrigger>
{!currentLLMSupportsImageOutput ? (
{!isImageGenerationAvailable && (
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
To use Image Generation, select GPT-4o or another
image compatible model as the default model for
this Assistant.
Image Generation requires an OpenAI or Azure Dalle
configuration.
</p>
</TooltipContent>
) : (
!isImageGenerationAvailable && (
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
Image Generation requires an OpenAI or Azure
Dalle configuration.
</p>
</TooltipContent>
)
)}
</Tooltip>
</TooltipProvider>

View File

@@ -142,6 +142,8 @@ export function CustomLLMProviderUpdateForm({
},
body: JSON.stringify({
...values,
// For custom llm providers, all model names are displayed
display_model_names: values.model_names,
custom_config: customConfigProcessing(values.custom_config_list),
}),
});

View File

@@ -50,6 +50,7 @@ import {
useContext,
useEffect,
useLayoutEffect,
useMemo,
useRef,
useState,
} from "react";
@@ -243,9 +244,9 @@ export function ChatPage({
};
const llmOverrideManager = useLlmOverride(
modelVersionFromSearchParams || (user?.preferences.default_model ?? null),
selectedChatSession,
defaultTemperature
llmProviders,
user?.preferences.default_model,
selectedChatSession
);
const [alternativeAssistant, setAlternativeAssistant] =
@@ -269,20 +270,21 @@ export function ChatPage({
// Always set the model override for the chat session when an assistant, LLM provider, or user preference exists
useEffect(() => {
if (noAssistants) return;
const personaDefault = getLLMProviderOverrideForPersona(
liveAssistant,
llmProviders
);
if (personaDefault) {
llmOverrideManager.setLlmOverride(personaDefault);
llmOverrideManager.updateLLMOverride(personaDefault);
} else if (user?.preferences.default_model) {
llmOverrideManager.setLlmOverride(
llmOverrideManager.updateLLMOverride(
destructureValue(user?.preferences.default_model)
);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [liveAssistant, llmProviders, user?.preferences.default_model]);
}, [liveAssistant, user?.preferences.default_model]);
const stopGenerating = () => {
const currentSession = currentSessionId();
@@ -370,7 +372,7 @@ export function ChatPage({
// reset LLM overrides (based on chat session!)
llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
llmOverrideManager.setTemperature(null);
llmOverrideManager.updateTemperature(null);
// remove uploaded files
setCurrentMessageFiles([]);
@@ -1514,7 +1516,7 @@ export function ChatPage({
setPopup({
type: "error",
message:
"The current Assistant does not support image input. Please select an assistant with Vision support.",
"The current model does not support image input. Please select a model with Vision support.",
});
return;
}
@@ -1722,6 +1724,14 @@ export function ChatPage({
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [messageHistory]);
const imageFileInMessageHistory = useMemo(() => {
return messageHistory
.filter((message) => message.type === "user")
.some((message) =>
message.files.some((file) => file.type === ChatFileType.IMAGE)
);
}, [messageHistory]);
const currentVisibleRange = visibleRange.get(currentSessionId()) || {
start: 0,
end: 0,
@@ -2435,6 +2445,9 @@ export function ChatPage({
</div>
)}
<ChatInputBar
sessionContainsImageFiles={
imageFileInMessageHistory
}
showConfigureAPIKey={() =>
setShowApiKeyModal(true)
}

View File

@@ -125,9 +125,9 @@ export default function RegenerateOption({
onHoverChange: (isHovered: boolean) => void;
onDropdownVisibleChange: (isVisible: boolean) => void;
}) {
const llmOverrideManager = useLlmOverride();
const { llmProviders } = useChatContext();
const llmOverrideManager = useLlmOverride(llmProviders);
const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);
const llmOptionsByProvider: {

View File

@@ -68,6 +68,7 @@ export function ChatInputBar({
alternativeAssistant,
chatSessionId,
inputPrompts,
sessionContainsImageFiles,
}: {
showConfigureAPIKey: () => void;
openModelSettings: () => void;
@@ -90,6 +91,7 @@ export function ChatInputBar({
handleFileUpload: (files: File[]) => void;
textAreaRef: React.RefObject<HTMLTextAreaElement>;
chatSessionId?: string;
sessionContainsImageFiles: boolean;
}) {
useEffect(() => {
const textarea = textAreaRef.current;
@@ -558,7 +560,6 @@ export function ChatInputBar({
tab
content={(close, ref) => (
<LlmTab
currentAssistant={alternativeAssistant || selectedAssistant}
openModelSettings={openModelSettings}
currentLlm={
llmOverrideManager.llmOverride.modelName ||
@@ -572,6 +573,7 @@ export function ChatInputBar({
ref={ref}
llmOverrideManager={llmOverrideManager}
chatSessionId={chatSessionId}
imageFilesPresent={sessionContainsImageFiles}
/>
)}
position="top"

View File

@@ -144,3 +144,6 @@ export interface StreamingError {
error: string;
stack_trace: string;
}
export const isAnthropic = (provider: string, modelName: string) =>
provider === "anthropic" || modelName.toLowerCase().includes("claude");

View File

@@ -8,7 +8,6 @@ import { destructureValue } from "@/lib/llm/utils";
import { updateModelOverrideForChatSession } from "../../lib";
import { GearIcon } from "@/components/icons/icons";
import { LlmList } from "@/components/llm/LLMList";
import { checkPersonaRequiresImageGeneration } from "@/app/admin/assistants/lib";
interface LlmTabProps {
llmOverrideManager: LlmOverrideManager;
@@ -16,7 +15,7 @@ interface LlmTabProps {
openModelSettings: () => void;
chatSessionId?: string;
close: () => void;
currentAssistant: Persona;
imageFilesPresent: boolean;
}
export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
@@ -27,15 +26,13 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
currentLlm,
close,
openModelSettings,
currentAssistant,
imageFilesPresent,
},
ref
) => {
const requiresImageGeneration =
checkPersonaRequiresImageGeneration(currentAssistant);
const { llmProviders } = useChatContext();
const { setLlmOverride, temperature, setTemperature } = llmOverrideManager;
const { updateLLMOverride, temperature, updateTemperature } =
llmOverrideManager;
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
const [localTemperature, setLocalTemperature] = useState<number>(
temperature || 0
@@ -43,11 +40,11 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
const debouncedSetTemperature = useCallback(
(value: number) => {
const debouncedFunction = debounce((value: number) => {
setTemperature(value);
updateTemperature(value);
}, 300);
return debouncedFunction(value);
},
[setTemperature]
[updateTemperature]
);
const handleTemperatureChange = (value: number) => {
@@ -69,14 +66,14 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
</button>
</div>
<LlmList
requiresImageGeneration={requiresImageGeneration}
llmProviders={llmProviders}
currentLlm={currentLlm}
imageFilesPresent={imageFilesPresent}
onSelect={(value: string | null) => {
if (value == null) {
return;
}
setLlmOverride(destructureValue(value));
updateLLMOverride(destructureValue(value));
if (chatSessionId) {
updateModelOverrideForChatSession(chatSessionId, value as string);
}

View File

@@ -6,6 +6,14 @@ import {
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { FiAlertTriangle } from "react-icons/fi";
interface LlmListProps {
llmProviders: LLMProviderDescriptor[];
currentLlm: string;
@@ -13,7 +21,7 @@ interface LlmListProps {
userDefault?: string | null;
scrollable?: boolean;
hideProviderIcon?: boolean;
requiresImageGeneration?: boolean;
imageFilesPresent?: boolean;
}
export const LlmList: React.FC<LlmListProps> = ({
@@ -22,7 +30,7 @@ export const LlmList: React.FC<LlmListProps> = ({
onSelect,
userDefault,
scrollable,
requiresImageGeneration,
imageFilesPresent,
}) => {
const llmOptionsByProvider: {
[provider: string]: {
@@ -31,6 +39,7 @@ export const LlmList: React.FC<LlmListProps> = ({
icon: React.FC<{ size?: number; className?: string }>;
}[];
} = {};
const uniqueModelNames = new Set<string>();
llmProviders.forEach((llmProvider) => {
@@ -62,7 +71,9 @@ export const LlmList: React.FC<LlmListProps> = ({
return (
<div
className={`${scrollable ? "max-h-[200px] include-scrollbar" : "max-h-[300px]"} bg-background-175 flex flex-col gap-y-1 overflow-y-scroll`}
className={`${
scrollable ? "max-h-[200px] include-scrollbar" : "max-h-[300px]"
} bg-background-175 flex flex-col gap-y-1 overflow-y-scroll`}
>
{userDefault && (
<button
@@ -79,25 +90,36 @@ export const LlmList: React.FC<LlmListProps> = ({
</button>
)}
{llmOptions.map(({ name, icon, value }, index) => {
if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
return (
<button
type="button"
key={index}
className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${
currentLlm == name
? "bg-background-200"
: "bg-background hover:bg-background-100"
} text-left rounded`}
onClick={() => onSelect(value)}
>
{icon({ size: 16 })}
{getDisplayNameForModel(name)}
</button>
);
}
})}
{llmOptions.map(({ name, icon, value }, index) => (
<button
type="button"
key={index}
className={`w-full py-1.5 flex items-center justify-start gap-x-2 px-2 text-sm ${
currentLlm == name
? "bg-background-200"
: "bg-background hover:bg-background-100"
} text-left rounded`}
onClick={() => onSelect(value)}
>
{icon({ size: 16 })}
{getDisplayNameForModel(name)}
{imageFilesPresent && !checkLLMSupportsImageInput(name) && (
<TooltipProvider>
<Tooltip delayDuration={0}>
<TooltipTrigger className="my-auto flex ites-center ml-auto">
<FiAlertTriangle className="text-alert" size={16} />
</TooltipTrigger>
<TooltipContent>
<p className="text-xs">
This LLM is not vision-capable and cannot process image
files present in your chat session.
</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
)}
</button>
))}
</div>
);
};

View File

@@ -10,12 +10,13 @@ import { errorHandlingFetcher } from "./fetcher";
import { useContext, useEffect, useState } from "react";
import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
import { SourceMetadata } from "./search/interfaces";
import { destructureValue } from "./llm/utils";
import { ChatSession } from "@/app/chat/interfaces";
import { destructureValue, structureValue } from "./llm/utils";
import { ChatSession, isAnthropic } from "@/app/chat/interfaces";
import { UsersResponse } from "./users/interfaces";
import { Credential } from "./connectors/credentials";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { PersonaCategory } from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
const CREDENTIAL_URL = "/api/manage/admin/credential";
@@ -71,7 +72,9 @@ export const useConnectorCredentialIndexingStatus = (
getEditable = false
) => {
const { mutate } = useSWRConfig();
const url = `${INDEXING_STATUS_URL}${getEditable ? "?get_editable=true" : ""}`;
const url = `${INDEXING_STATUS_URL}${
getEditable ? "?get_editable=true" : ""
}`;
const swrResponse = useSWR<ConnectorIndexingStatus<any, any>[]>(
url,
errorHandlingFetcher,
@@ -153,76 +156,102 @@ export interface LlmOverride {
export interface LlmOverrideManager {
llmOverride: LlmOverride;
setLlmOverride: React.Dispatch<React.SetStateAction<LlmOverride>>;
updateLLMOverride: (newOverride: LlmOverride) => void;
globalDefault: LlmOverride;
setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
temperature: number | null;
setTemperature: React.Dispatch<React.SetStateAction<number | null>>;
updateTemperature: (temperature: number | null) => void;
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
}
export function useLlmOverride(
llmProviders: LLMProviderDescriptor[],
globalModel?: string | null,
currentChatSession?: ChatSession,
defaultTemperature?: number
): LlmOverrideManager {
const getValidLlmOverride = (
overrideModel: string | null | undefined
): LlmOverride => {
if (overrideModel) {
const model = destructureValue(overrideModel);
const provider = llmProviders.find(
(p) =>
p.model_names.includes(model.modelName) &&
p.provider === model.provider
);
if (provider) {
return { ...model, name: provider.name };
}
}
return { name: "", provider: "", modelName: "" };
};
const [globalDefault, setGlobalDefault] = useState<LlmOverride>(
globalModel != null
? destructureValue(globalModel)
: {
name: "",
provider: "",
modelName: "",
}
getValidLlmOverride(globalModel)
);
const updateLLMOverride = (newOverride: LlmOverride) => {
setLlmOverride(
getValidLlmOverride(
structureValue(
newOverride.name,
newOverride.provider,
newOverride.modelName
)
)
);
};
const [llmOverride, setLlmOverride] = useState<LlmOverride>(
currentChatSession && currentChatSession.current_alternate_model
? destructureValue(currentChatSession.current_alternate_model)
: {
name: "",
provider: "",
modelName: "",
}
? getValidLlmOverride(currentChatSession.current_alternate_model)
: { name: "", provider: "", modelName: "" }
);
const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
setLlmOverride(
chatSession && chatSession.current_alternate_model
? destructureValue(chatSession.current_alternate_model)
? getValidLlmOverride(chatSession.current_alternate_model)
: globalDefault
);
};
const [temperature, setTemperature] = useState<number | null>(
defaultTemperature != undefined ? defaultTemperature : 0
defaultTemperature !== undefined ? defaultTemperature : 0
);
useEffect(() => {
setGlobalDefault(
globalModel != null
? destructureValue(globalModel)
: {
name: "",
provider: "",
modelName: "",
}
);
}, [globalModel]);
setGlobalDefault(getValidLlmOverride(globalModel));
}, [globalModel, llmProviders]);
useEffect(() => {
setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0);
}, [defaultTemperature]);
useEffect(() => {
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
setTemperature((prevTemp) => Math.min(prevTemp ?? 0, 1.0));
}
}, [llmOverride]);
const updateTemperature = (temperature: number | null) => {
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
setTemperature((prevTemp) => Math.min(temperature ?? 0, 1.0));
} else {
setTemperature(temperature);
}
};
return {
updateModelOverrideForChatSession,
llmOverride,
setLlmOverride,
updateLLMOverride,
globalDefault,
setGlobalDefault,
temperature,
setTemperature,
updateTemperature,
};
}
/*
EE Only APIs
*/