Mirror of https://github.com/onyx-dot-app/onyx.git
Compare commits: v1.0.0-clo ... special-ho (6 commits)

Commit SHAs:
73af85b524
151c6b526a
e5770e35d8
fd25753013
88aad7f411
5ca4efc24f
@@ -265,5 +265,6 @@ celery_app.autodiscover_tasks(
         "danswer.background.celery.tasks.pruning",
         "danswer.background.celery.tasks.shared",
         "danswer.background.celery.tasks.vespa",
+        "danswer.background.celery.tasks.llm_model_update",
     ]
 )
@@ -1,6 +1,7 @@
 from datetime import timedelta
 from typing import Any

+from danswer.configs.app_configs import LLM_MODEL_UPDATE_API_URL
 from danswer.configs.constants import DanswerCeleryPriority

@@ -55,6 +56,19 @@ tasks_to_schedule = [
     },
 ]

+# Only add the LLM model update task if the API URL is configured
+if LLM_MODEL_UPDATE_API_URL:
+    tasks_to_schedule.append(
+        {
+            "name": "check-for-llm-model-update",
+            "task": "check_for_llm_model_update",
+            "schedule": timedelta(hours=1),  # Check every hour
+            "options": {
+                "priority": DanswerCeleryPriority.LOW,
+            },
+        }
+    )
+

 def get_tasks_to_schedule() -> list[dict[str, Any]]:
     return tasks_to_schedule
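For context, a minimal sketch (wiring and names assumed, not taken from this diff) of how entries like the one above are typically handed to Celery beat; the actual Danswer beat app setup may differ:

import os
from celery import Celery

celery_app = Celery("danswer")

# Hypothetical wiring: convert each entry produced by get_tasks_to_schedule()
# into a Celery beat schedule entry keyed by its display name.
celery_app.conf.beat_schedule = {
    task["name"]: {
        "task": task["task"],
        "schedule": task["schedule"],
        "options": task.get("options", {}),
    }
    for task in get_tasks_to_schedule()
}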
@@ -0,0 +1,120 @@
+from typing import Any
+
+import requests
+from celery import shared_task
+from celery import Task
+
+from danswer.background.celery.apps.app_base import task_logger
+from danswer.configs.app_configs import JOB_TIMEOUT
+from danswer.configs.app_configs import LLM_MODEL_UPDATE_API_URL
+from danswer.db.engine import get_session_with_tenant
+from danswer.db.models import LLMProvider
+
+
+def _process_model_list_response(model_list_json: Any) -> list[str]:
+    # Handle case where response is wrapped in a "data" field
+    if isinstance(model_list_json, dict):
+        if "data" in model_list_json:
+            model_list_json = model_list_json["data"]
+        elif "models" in model_list_json:
+            model_list_json = model_list_json["models"]
+        else:
+            raise ValueError(
+                "Invalid response from API - expected dict with 'data' or "
+                f"'models' field, got {type(model_list_json)}"
+            )
+
+    if not isinstance(model_list_json, list):
+        raise ValueError(
+            f"Invalid response from API - expected list, got {type(model_list_json)}"
+        )
+
+    # Handle both string list and object list cases
+    model_names: list[str] = []
+    for item in model_list_json:
+        if isinstance(item, str):
+            model_names.append(item)
+        elif isinstance(item, dict):
+            if "model_name" in item:
+                model_names.append(item["model_name"])
+            elif "id" in item:
+                model_names.append(item["id"])
+            else:
+                raise ValueError(
+                    f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
+                )
+        else:
+            raise ValueError(
+                f"Invalid item in model list - expected string or dict, got {type(item)}"
+            )
+
+    return model_names
+
+
+@shared_task(
+    name="check_for_llm_model_update",
+    soft_time_limit=JOB_TIMEOUT,
+    trail=False,
+    bind=True,
+)
+def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
+    if not LLM_MODEL_UPDATE_API_URL:
+        raise ValueError("LLM model update API URL not configured")
+
+    # First fetch the models from the API
+    try:
+        response = requests.get(LLM_MODEL_UPDATE_API_URL)
+        response.raise_for_status()
+        available_models = _process_model_list_response(response.json())
+        task_logger.info(f"Found available models: {available_models}")
+
+    except Exception:
+        task_logger.exception("Failed to fetch models from API.")
+        return None
+
+    # Then update the database with the fetched models
+    with get_session_with_tenant(tenant_id) as db_session:
+        # Get the default LLM provider
+        default_provider = (
+            db_session.query(LLMProvider)
+            .filter(LLMProvider.is_default_provider.is_(True))
+            .first()
+        )
+
+        if not default_provider:
+            task_logger.warning("No default LLM provider found")
+            return None
+
+        # log change if any
+        old_models = set(default_provider.model_names or [])
+        new_models = set(available_models)
+        added_models = new_models - old_models
+        removed_models = old_models - new_models
+
+        if added_models:
+            task_logger.info(f"Adding models: {sorted(added_models)}")
+        if removed_models:
+            task_logger.info(f"Removing models: {sorted(removed_models)}")
+
+        # Update the provider's model list
+        default_provider.model_names = available_models
+        default_provider.display_model_names = available_models
+        # if the default model is no longer available, set it to the first model in the list
+        if default_provider.default_model_name not in available_models:
+            task_logger.info(
+                f"Default model {default_provider.default_model_name} not "
+                f"available, setting to first model in list: {available_models[0]}"
+            )
+            default_provider.default_model_name = available_models[0]
+        if default_provider.fast_default_model_name not in available_models:
+            task_logger.info(
+                f"Fast default model {default_provider.fast_default_model_name} "
+                f"not available, setting to first model in list: {available_models[0]}"
+            )
+            default_provider.fast_default_model_name = available_models[0]
+        db_session.commit()
+
+    if added_models or removed_models:
+        task_logger.info("Updated model list for default provider.")
+
+    return True
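As a quick illustration (inputs invented here, though they mirror the parametrized unit tests near the end of this diff), _process_model_list_response normalizes every accepted response shape to a flat list of model names:

assert _process_model_list_response(["gpt-4", "gpt-4o"]) == ["gpt-4", "gpt-4o"]
assert _process_model_list_response({"data": [{"id": "gpt-4"}]}) == ["gpt-4"]
assert _process_model_list_response(
    {"models": [{"model_name": "gpt-4"}, "gpt-4o"]}
) == ["gpt-4", "gpt-4o"]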
@@ -468,6 +468,8 @@ AZURE_DALLE_API_KEY = os.environ.get("AZURE_DALLE_API_KEY")
 AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE")
 AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")

+# LLM Model Update API endpoint
+LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
+
 # Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
 MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
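A deployment only needs this variable set before the workers start for the hourly check to be scheduled; the endpoint below is a made-up placeholder:

import os

# Hypothetical endpoint returning JSON in one of the shapes accepted by
# _process_model_list_response (a bare list, {"data": [...]}, or {"models": [...]}).
os.environ["LLM_MODEL_UPDATE_API_URL"] = "https://llm-gateway.example.com/v1/models"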
@@ -234,7 +234,12 @@ class Answer:
         # DEBUG: good breakpoint
         stream = self.llm.stream(
             prompt=current_llm_call.prompt_builder.build(),
-            tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
+            tools=(
+                [tool.tool_definition() for tool in current_llm_call.tools]
+                if self.using_tool_calling_llm
+                else None
+            )
+            or None,
             tool_choice=(
                 "required"
                 if current_llm_call.tools and current_llm_call.force_use_tool.force_use
@@ -13,6 +13,7 @@ from danswer.llm.interfaces import LLMConfig
 from danswer.llm.utils import build_content_with_imgs
 from danswer.llm.utils import check_message_tokens
 from danswer.llm.utils import message_to_prompt_and_imgs
+from danswer.llm.utils import model_supports_image_input
 from danswer.llm.utils import translate_history_to_basemessages
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
@@ -67,6 +68,7 @@ class AnswerPromptBuilder:
             provider_type=llm_config.model_provider,
             model_name=llm_config.model_name,
         )
+        self.llm_config = llm_config
         self.llm_tokenizer_encode_func = cast(
             Callable[[str], list[int]], llm_tokenizer.encode
         )
@@ -75,7 +77,13 @@ class AnswerPromptBuilder:
         (
             self.message_history,
             self.history_token_cnts,
-        ) = translate_history_to_basemessages(message_history)
+        ) = translate_history_to_basemessages(
+            message_history,
+            exclude_images=not model_supports_image_input(
+                self.llm_config.model_name,
+                self.llm_config.model_provider,
+            ),
+        )

         # for cases where like the QA flow where we want to condense the chat history
         # into a single message rather than a sequence of User / Assistant messages
@@ -84,7 +92,10 @@ class AnswerPromptBuilder:
         self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
         self.user_message_and_token_cnt = (
             user_message,
-            check_message_tokens(user_message, self.llm_tokenizer_encode_func),
+            check_message_tokens(
+                user_message,
+                self.llm_tokenizer_encode_func,
+            ),
         )

         self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
@@ -1,3 +1,4 @@
+import copy
 import io
 import json
 from collections.abc import Callable
@@ -105,6 +106,7 @@ def litellm_exception_to_error_msg(

 def translate_danswer_msg_to_langchain(
     msg: Union[ChatMessage, "PreviousMessage"],
+    exclude_images: bool = False,
 ) -> BaseMessage:
     files: list[InMemoryChatFile] = []
@@ -112,7 +114,9 @@ def translate_danswer_msg_to_langchain(
     # attached. Just ignore them for now.
     if not isinstance(msg, ChatMessage):
         files = msg.files
-    content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
+    content = build_content_with_imgs(
+        msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
+    )

     if msg.message_type == MessageType.SYSTEM:
         raise ValueError("System messages are not currently part of history")
@@ -125,10 +129,10 @@ def translate_danswer_msg_to_langchain(


 def translate_history_to_basemessages(
-    history: list[ChatMessage] | list["PreviousMessage"],
+    history: list[ChatMessage] | list["PreviousMessage"], exclude_images: bool = False
 ) -> tuple[list[BaseMessage], list[int]]:
     history_basemessages = [
-        translate_danswer_msg_to_langchain(msg)
+        translate_danswer_msg_to_langchain(msg, exclude_images)
         for msg in history
         if msg.token_count != 0
     ]
@@ -188,6 +192,7 @@ def build_content_with_imgs(
     files: list[InMemoryChatFile] | None = None,
     img_urls: list[str] | None = None,
     message_type: MessageType = MessageType.USER,
+    exclude_images: bool = False,
 ) -> str | list[str | dict[str, Any]]:  # matching Langchain's BaseMessage content type
     files = files or []
@@ -202,7 +207,7 @@ def build_content_with_imgs(

     message_main_content = _build_content(message, files)

-    if not img_files and not img_urls:
+    if exclude_images or (not img_files and not img_urls):
         return message_main_content

     return cast(
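A minimal sketch of the new flag's effect (variable names invented for the example): when the target model lacks vision support, history translation now drops attached images instead of forwarding them.

# Hypothetical call - with exclude_images=True the function short-circuits to the
# plain text content even when image files are attached.
content = build_content_with_imgs(
    "What is in this picture?",
    files=uploaded_files,  # assumed list[InMemoryChatFile] containing an image
    exclude_images=True,
)
assert isinstance(content, str)  # no image payload is included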
@@ -383,6 +388,72 @@ def test_llm(llm: LLM) -> str | None:
     return error_msg


+def get_model_map() -> dict:
+    starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
+
+    # NOTE: we could add additional models here in the future,
+    # but for now there is no point. Ollama allows the user
+    # to specify their desired max context window, and it's
+    # unlikely to be standard across users even for the same model
+    # (it heavily depends on their hardware). For now, we'll just
+    # rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
+    # for model_name in [
+    #     "llama3.2",
+    #     "llama3.2:1b",
+    #     "llama3.2:3b",
+    #     "llama3.2:11b",
+    #     "llama3.2:90b",
+    # ]:
+    #     starting_map[f"ollama/{model_name}"] = {
+    #         "max_tokens": 128000,
+    #         "max_input_tokens": 128000,
+    #         "max_output_tokens": 128000,
+    #     }
+
+    return starting_map
+
+
+def _strip_extra_provider_from_model_name(model_name: str) -> str:
+    return model_name.split("/")[1] if "/" in model_name else model_name
+
+
+def _strip_colon_from_model_name(model_name: str) -> str:
+    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
+
+
+def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | None:
+    stripped_model_name = _strip_extra_provider_from_model_name(model_name)
+
+    model_names = [
+        model_name,
+        # Remove leading extra provider. Usually for cases where user has a
+        # custom model proxy which appends another prefix
+        _strip_extra_provider_from_model_name(model_name),
+        # remove :XXXX from the end, if present. Needed for ollama.
+        _strip_colon_from_model_name(model_name),
+        _strip_colon_from_model_name(stripped_model_name),
+    ]
+
+    # Filter out None values and deduplicate model names
+    filtered_model_names = [name for name in model_names if name]
+
+    # First try all model names with provider prefix
+    for model_name in filtered_model_names:
+        model_obj = model_map.get(f"{provider}/{model_name}")
+        if model_obj:
+            logger.debug(f"Using model object for {provider}/{model_name}")
+            return model_obj
+
+    # Then try all model names without provider prefix
+    for model_name in filtered_model_names:
+        model_obj = model_map.get(model_name)
+        if model_obj:
+            logger.debug(f"Using model object for {model_name}")
+            return model_obj
+
+    return None
+
+
 def get_llm_max_tokens(
     model_map: dict,
     model_name: str,
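A short illustration of the lookup candidates (example name invented; the behavior follows directly from the string helpers above):

# "custom_proxy/llama3.2:3b" produces these candidates, tried in order,
# first as f"{provider}/{name}" and then bare:
#   "custom_proxy/llama3.2:3b"   (as given)
#   "llama3.2:3b"                (extra provider prefix stripped)
#   "custom_proxy/llama3.2"      (":3b" suffix stripped)
#   "llama3.2"                   (both stripped)
assert _strip_extra_provider_from_model_name("custom_proxy/llama3.2:3b") == "llama3.2:3b"
assert _strip_colon_from_model_name("llama3.2:3b") == "llama3.2"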
@@ -395,22 +466,11 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS

     try:
-        model_obj = model_map.get(f"{model_provider}/{model_name}")
-        if model_obj:
-            logger.debug(f"Using model object for {model_provider}/{model_name}")
-
-        if not model_obj:
-            model_obj = model_map.get(model_name)
-            if model_obj:
-                logger.debug(f"Using model object for {model_name}")
-
-        if not model_obj:
-            model_name_split = model_name.split("/")
-            if len(model_name_split) > 1:
-                model_obj = model_map.get(model_name_split[1])
-            if model_obj:
-                logger.debug(f"Using model object for {model_name_split[1]}")
-
+        model_obj = _find_model_obj(
+            model_map,
+            model_provider,
+            model_name,
+        )
         if not model_obj:
             raise RuntimeError(
                 f"No litellm entry found for {model_provider}/{model_name}"
@@ -501,3 +561,23 @@ def get_max_input_tokens(
         raise RuntimeError("No tokens for input for the LLM given settings")

     return input_toks
+
+
+def model_supports_image_input(model_name: str, model_provider: str) -> bool:
+    model_map = get_model_map()
+    try:
+        model_obj = _find_model_obj(
+            model_map,
+            model_provider,
+            model_name,
+        )
+        if not model_obj:
+            raise RuntimeError(
+                f"No litellm entry found for {model_provider}/{model_name}"
+            )
+        return model_obj.get("supports_vision", False)
+    except Exception:
+        logger.exception(
+            f"Failed to get model object for {model_provider}/{model_name}"
+        )
+        return False
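Assuming litellm's bundled model map marks vision-capable entries with supports_vision (true for current litellm releases, though exact entries vary by version), usage looks like:

# Returns True when the litellm entry sets supports_vision; an unknown model
# logs the failure and returns False instead of raising.
print(model_supports_image_input("gpt-4o", "openai"))            # expected True
print(model_supports_image_input("not-a-real-model", "openai"))  # False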
@@ -144,19 +144,20 @@ def put_llm_provider(
         detail=f"LLM Provider with name {llm_provider.name} already exists",
     )

-    # Ensure default_model_name and fast_default_model_name are in display_model_names
-    # This is necessary for custom models and Bedrock/Azure models
-    if llm_provider.display_model_names is None:
-        llm_provider.display_model_names = []
-
-    if llm_provider.default_model_name not in llm_provider.display_model_names:
-        llm_provider.display_model_names.append(llm_provider.default_model_name)
-
-    if (
-        llm_provider.fast_default_model_name
-        and llm_provider.fast_default_model_name not in llm_provider.display_model_names
-    ):
-        llm_provider.display_model_names.append(llm_provider.fast_default_model_name)
+    if llm_provider.display_model_names is not None:
+        # Ensure default_model_name and fast_default_model_name are in display_model_names
+        # This is necessary for custom models and Bedrock/Azure models
+        if llm_provider.default_model_name not in llm_provider.display_model_names:
+            llm_provider.display_model_names.append(llm_provider.default_model_name)
+
+        if (
+            llm_provider.fast_default_model_name
+            and llm_provider.fast_default_model_name
+            not in llm_provider.display_model_names
+        ):
+            llm_provider.display_model_names.append(
+                llm_provider.fast_default_model_name
+            )

     try:
         return upsert_llm_provider(
@@ -14,6 +14,7 @@ from danswer.llm.answering.prompts.build import AnswerPromptBuilder
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import build_content_with_imgs
 from danswer.llm.utils import message_to_string
+from danswer.llm.utils import model_supports_image_input
 from danswer.prompts.constants import GENERAL_SEP_PAT
 from danswer.tools.message import ToolCallSummary
 from danswer.tools.models import ToolResponse
@@ -293,6 +294,11 @@ class ImageGenerationTool(Tool):
             build_image_generation_user_prompt(
                 query=prompt_builder.get_user_message_content(),
                 img_urls=img_urls,
+                prompts=[img.revised_prompt for img in img_generation_response],
+                supports_image_input=model_supports_image_input(
+                    prompt_builder.llm_config.model_name,
+                    prompt_builder.llm_config.model_provider,
+                ),
             )
         )
@@ -6,16 +6,38 @@ from danswer.llm.utils import build_content_with_imgs
 IMG_GENERATION_SUMMARY_PROMPT = """
 You have just created the attached images in response to the following query: "{query}".

+{img_urls}
+
 Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
 """

+IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES = """
+You have generated images based on the following query: "{query}".
+The prompts used to generate these images were: {prompts}
+
+Describe what the generated images depict based on the query and prompts provided.
+Summarize the key elements and content of the images in a sentence or two. Be specific
+about what was generated rather than speculating about what the images 'likely' contain.
+"""
+

 def build_image_generation_user_prompt(
-    query: str, img_urls: list[str] | None = None
+    query: str,
+    supports_image_input: bool,
+    img_urls: list[str] | None = None,
+    prompts: list[str] | None = None,
 ) -> HumanMessage:
-    return HumanMessage(
-        content=build_content_with_imgs(
-            message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
-            img_urls=img_urls,
-        )
-    )
+    if supports_image_input:
+        return HumanMessage(
+            content=build_content_with_imgs(
+                message=IMG_GENERATION_SUMMARY_PROMPT.format(
+                    query=query, img_urls=img_urls
+                ).strip(),
+            )
+        )
+    else:
+        return HumanMessage(
+            content=IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES.format(
+                query=query, prompts=prompts
+            ).strip()
+        )
@@ -29,7 +29,7 @@ trafilatura==1.12.2
 langchain==0.1.17
 langchain-core==0.1.50
 langchain-text-splitters==0.0.1
-litellm==1.50.2
+litellm==1.55.4
 lxml==5.3.0
 lxml_html_clean==0.2.2
 llama-index==0.9.45
@@ -38,7 +38,7 @@ msal==1.28.0
 nltk==3.8.1
 Office365-REST-Python-Client==2.5.9
 oauthlib==3.2.2
-openai==1.52.2
+openai==1.55.3
 openpyxl==3.1.2
 playwright==1.41.2
 psutil==5.9.5
@@ -4,8 +4,12 @@ from collections.abc import Generator
 import pytest
 from sqlalchemy.orm import Session

+from danswer.auth.schemas import UserRole
 from danswer.db.engine import get_session_context_manager
 from danswer.db.search_settings import get_current_search_settings
+from tests.integration.common_utils.constants import GENERAL_HEADERS
+from tests.integration.common_utils.managers.user import build_email
+from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
 from tests.integration.common_utils.managers.user import UserManager
 from tests.integration.common_utils.reset import reset_all
 from tests.integration.common_utils.reset import reset_all_multitenant
@@ -57,6 +61,30 @@ def new_admin_user(reset: None) -> DATestUser | None:
     return None


+@pytest.fixture
+def admin_user() -> DATestUser | None:
+    try:
+        return UserManager.create(name="admin_user")
+    except Exception:
+        pass
+
+    try:
+        return UserManager.login_as_user(
+            DATestUser(
+                id="",
+                email=build_email("admin_user"),
+                password=DEFAULT_PASSWORD,
+                headers=GENERAL_HEADERS,
+                role=UserRole.ADMIN,
+                is_active=True,
+            )
+        )
+    except Exception:
+        pass
+
+    return None
+
+
 @pytest.fixture
 def reset_multitenant() -> None:
     reset_all_multitenant()
@@ -7,37 +7,12 @@ import requests

 from tests.integration.common_utils.constants import API_SERVER_URL
-from tests.integration.common_utils.constants import GENERAL_HEADERS
 from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
-from tests.integration.common_utils.managers.user import build_email
-from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
-from tests.integration.common_utils.managers.user import UserManager
 from tests.integration.common_utils.test_models import DATestLLMProvider
 from tests.integration.common_utils.test_models import DATestUser

 BASE_URL = f"{API_SERVER_URL}/openai-assistants"


-@pytest.fixture
-def admin_user() -> DATestUser | None:
-    try:
-        return UserManager.create("admin_user")
-    except Exception:
-        pass
-
-    try:
-        return UserManager.login_as_user(
-            DATestUser(
-                id="",
-                email=build_email("admin_user"),
-                password=DEFAULT_PASSWORD,
-                headers=GENERAL_HEADERS,
-            )
-        )
-    except Exception:
-        pass
-
-    return None
-
-
 @pytest.fixture
 def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
     return LLMProviderManager.create(user_performing_action=admin_user)
@@ -0,0 +1,120 @@
+import uuid
+
+import requests
+
+from tests.integration.common_utils.constants import API_SERVER_URL
+from tests.integration.common_utils.test_models import DATestUser
+
+
+_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]
+
+
+def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
+    """Utility function to fetch an LLM provider by ID"""
+    response = requests.get(
+        f"{API_SERVER_URL}/admin/llm/provider",
+        headers=admin_user.headers,
+    )
+    assert response.status_code == 200
+    providers = response.json()
+    return next((p for p in providers if p["id"] == provider_id), None)
+
+
+def test_create_llm_provider_without_display_model_names(
+    admin_user: DATestUser,
+) -> None:
+    """Test creating an LLM provider without specifying
+    display_model_names and verify it's null in response"""
+    # Create LLM provider without display_model_names
+    response = requests.put(
+        f"{API_SERVER_URL}/admin/llm/provider",
+        headers=admin_user.headers,
+        json={
+            "name": str(uuid.uuid4()),
+            "provider": "openai",
+            "default_model_name": _DEFAULT_MODELS[0],
+            "model_names": _DEFAULT_MODELS,
+            "is_public": True,
+            "groups": [],
+        },
+    )
+    assert response.status_code == 200
+    created_provider = response.json()
+    provider_data = _get_provider_by_id(admin_user, created_provider["id"])
+
+    # Verify display_model_names is None/null
+    assert provider_data is not None
+    assert provider_data["model_names"] == _DEFAULT_MODELS
+    assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
+    assert provider_data["display_model_names"] is None
+
+
+def test_update_llm_provider_model_names(admin_user: DATestUser) -> None:
+    """Test updating an LLM provider's model_names"""
+    # First create the provider with a single model name
+    name = str(uuid.uuid4())
+    response = requests.put(
+        f"{API_SERVER_URL}/admin/llm/provider",
+        headers=admin_user.headers,
+        json={
+            "name": name,
+            "provider": "openai",
+            "default_model_name": _DEFAULT_MODELS[0],
+            "model_names": [_DEFAULT_MODELS[0]],
+            "is_public": True,
+            "groups": [],
+        },
+    )
+    assert response.status_code == 200
+    created_provider = response.json()
+
+    # Update with the full model_names list
+    response = requests.put(
+        f"{API_SERVER_URL}/admin/llm/provider",
+        headers=admin_user.headers,
+        json={
+            "id": created_provider["id"],
+            "name": name,
+            "provider": created_provider["provider"],
+            "default_model_name": _DEFAULT_MODELS[0],
+            "model_names": _DEFAULT_MODELS,
+            "is_public": True,
+            "groups": [],
+        },
+    )
+    assert response.status_code == 200
+
+    # Verify update
+    provider_data = _get_provider_by_id(admin_user, created_provider["id"])
+    assert provider_data is not None
+    assert provider_data["model_names"] == _DEFAULT_MODELS
+
+
+def test_delete_llm_provider(admin_user: DATestUser) -> None:
+    """Test deleting an LLM provider"""
+    # Create a provider
+    response = requests.put(
+        f"{API_SERVER_URL}/admin/llm/provider",
+        headers=admin_user.headers,
+        json={
+            "name": "test-provider-delete",
+            "provider": "openai",
+            "default_model_name": _DEFAULT_MODELS[0],
+            "model_names": _DEFAULT_MODELS,
+            "is_public": True,
+            "groups": [],
+        },
+    )
+    assert response.status_code == 200
+    created_provider = response.json()
+
+    # Delete the provider
+    response = requests.delete(
+        f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
+        headers=admin_user.headers,
+    )
+    assert response.status_code == 200
+
+    # Verify provider is deleted by checking it's not in the list
+    provider_data = _get_provider_by_id(admin_user, created_provider["id"])
+    assert provider_data is None
backend/tests/integration/tests/seeding/test_seeding.py — new file, 118 lines
@@ -0,0 +1,118 @@
+import json
+import os
+from tempfile import NamedTemporaryFile
+
+import requests
+
+from danswer.db.engine import get_session_context_manager
+from danswer.db.models import Tool
+from danswer.server.features.persona.models import CreatePersonaRequest
+from danswer.server.manage.llm.models import LLMProviderUpsertRequest
+from danswer.server.settings.models import Settings
+from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
+from ee.danswer.server.seeding import SeedConfiguration
+from tests.integration.common_utils.constants import API_SERVER_URL
+from tests.integration.common_utils.managers.user import UserManager
+from tests.integration.common_utils.test_models import DATestUser
+
+
+def test_seeding(reset: None) -> None:
+    # Create admin user
+    admin_user: DATestUser = UserManager.create(name="admin_user")
+
+    # Create temporary files for testing
+    with (
+        NamedTemporaryFile(mode="w", suffix=".json") as tool_file,
+        NamedTemporaryFile(mode="w", suffix=".svg") as logo_file,
+    ):
+        # Write test tool definition
+        tool_definition = {
+            "openapi": "3.0.0",
+            "info": {"title": "Test Tool", "version": "1.0.0"},
+            "paths": {},
+        }
+        json.dump(tool_definition, tool_file)
+        tool_file.flush()
+
+        # Write test logo
+        logo_file.write("<svg>Test Logo</svg>")
+        logo_file.flush()
+
+        # Create seed configuration
+        seed_config = SeedConfiguration(
+            llms=[
+                LLMProviderUpsertRequest(
+                    model_name="test-model",
+                    model_provider="test-provider",
+                    api_key="test-key",
+                )
+            ],
+            personas=[
+                CreatePersonaRequest(
+                    name="Test Persona",
+                    description="A test persona",
+                    num_chunks=5,
+                )
+            ],
+            settings=Settings(
+                enable_experimental_features=True,
+            ),
+            enterprise_settings=EnterpriseSettings(
+                disable_source_filters=True,
+            ),
+            seeded_logo_path=logo_file.name,
+            custom_tools=[
+                {
+                    "name": "test-tool",
+                    "description": "A test tool",
+                    "definition_path": tool_file.name,
+                }
+            ],
+        )
+
+        # Set environment variable with seed configuration
+        os.environ["ENV_SEED_CONFIGURATION"] = seed_config.model_dump_json()
+
+        # Verify seeded LLM
+        response = requests.get(
+            f"{API_SERVER_URL}/manage/llm-providers",
+            headers=admin_user.headers,
+        )
+        assert response.status_code == 200
+        llms = response.json()
+        assert any(llm["model_name"] == "test-model" for llm in llms)
+
+        # Verify seeded persona
+        response = requests.get(
+            f"{API_SERVER_URL}/personas",
+            headers=admin_user.headers,
+        )
+        assert response.status_code == 200
+        personas = response.json()
+        assert any(persona["name"] == "Test Persona" for persona in personas)
+
+        # Verify seeded tool
+        with get_session_context_manager() as db_session:
+            tools = db_session.query(Tool).all()
+            assert any(tool.name == "test-tool" for tool in tools)
+
+        # Verify settings
+        response = requests.get(
+            f"{API_SERVER_URL}/settings",
+            headers=admin_user.headers,
+        )
+        assert response.status_code == 200
+        settings = response.json()
+        assert settings["enable_experimental_features"] is True
+
+        # Verify enterprise settings
+        response = requests.get(
+            f"{API_SERVER_URL}/enterprise-settings",
+            headers=admin_user.headers,
+        )
+        assert response.status_code == 200
+        ee_settings = response.json()
+        assert ee_settings["disable_source_filters"] is True
+
+        # Clean up
+        os.environ.pop("ENV_SEED_CONFIGURATION", None)
@@ -0,0 +1,92 @@
+import pytest
+
+from danswer.background.celery.tasks.llm_model_update.tasks import (
+    _process_model_list_response,
+)
+
+
+@pytest.mark.parametrize(
+    "input_data,expected_result,expected_error,error_match",
+    [
+        # Success cases
+        (
+            ["gpt-4", "gpt-3.5-turbo", "claude-2"],
+            ["gpt-4", "gpt-3.5-turbo", "claude-2"],
+            None,
+            None,
+        ),
+        (
+            [
+                {"model_name": "gpt-4", "other_field": "value"},
+                {"model_name": "gpt-3.5-turbo", "other_field": "value"},
+            ],
+            ["gpt-4", "gpt-3.5-turbo"],
+            None,
+            None,
+        ),
+        (
+            [
+                {"id": "gpt-4", "other_field": "value"},
+                {"id": "gpt-3.5-turbo", "other_field": "value"},
+            ],
+            ["gpt-4", "gpt-3.5-turbo"],
+            None,
+            None,
+        ),
+        (
+            {"data": ["gpt-4", "gpt-3.5-turbo"]},
+            ["gpt-4", "gpt-3.5-turbo"],
+            None,
+            None,
+        ),
+        (
+            {"models": ["gpt-4", "gpt-3.5-turbo"]},
+            ["gpt-4", "gpt-3.5-turbo"],
+            None,
+            None,
+        ),
+        (
+            {"models": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}]},
+            ["gpt-4", "gpt-3.5-turbo"],
+            None,
+            None,
+        ),
+        # Error cases
+        (
+            "not a list",
+            None,
+            ValueError,
+            "Invalid response from API - expected list",
+        ),
+        (
+            {"wrong_field": []},
+            None,
+            ValueError,
+            "Invalid response from API - expected dict with 'data' or 'models' field",
+        ),
+        (
+            [{"wrong_field": "value"}],
+            None,
+            ValueError,
+            "Invalid item in model list - expected dict with model_name or id",
+        ),
+        (
+            [42],
+            None,
+            ValueError,
+            "Invalid item in model list - expected string or dict",
+        ),
+    ],
+)
+def test_process_model_list_response(
+    input_data: dict | list,
+    expected_result: list[str] | None,
+    expected_error: type[Exception] | None,
+    error_match: str | None,
+) -> None:
+    if expected_error:
+        with pytest.raises(expected_error, match=error_match):
+            _process_model_list_response(input_data)
+    else:
+        result = _process_model_list_response(input_data)
+        assert result == expected_result
@@ -315,26 +315,10 @@ export function AssistantEditor({
     let enabledTools = Object.keys(values.enabled_tools_map)
       .map((toolId) => Number(toolId))
      .filter((toolId) => values.enabled_tools_map[toolId]);

     const searchToolEnabled = searchTool
       ? enabledTools.includes(searchTool.id)
       : false;
-    const imageGenerationToolEnabled = imageGenerationTool
-      ? enabledTools.includes(imageGenerationTool.id)
-      : false;
-
-    if (imageGenerationToolEnabled) {
-      if (
-        // model must support image input for image generation
-        // to work
-        !checkLLMSupportsImageInput(
-          values.llm_model_version_override || defaultModelName || ""
-        )
-      ) {
-        enabledTools = enabledTools.filter(
-          (toolId) => toolId !== imageGenerationTool!.id
-        );
-      }
-    }

     // if disable_retrieval is set, set num_chunks to 0
     // to tell the backend to not fetch any documents
@@ -743,7 +727,6 @@ export function AssistantEditor({
                 <TooltipTrigger asChild>
                   <div
                     className={`w-fit ${
-                      !currentLLMSupportsImageOutput ||
                       !isImageGenerationAvailable
                         ? "opacity-70 cursor-not-allowed"
                         : ""
@@ -756,30 +739,17 @@ export function AssistantEditor({
                         onChange={() => {
                           toggleToolInValues(imageGenerationTool.id);
                         }}
-                        disabled={
-                          !currentLLMSupportsImageOutput ||
-                          !isImageGenerationAvailable
-                        }
+                        disabled={!isImageGenerationAvailable}
                       />
                     </div>
                   </TooltipTrigger>
-                  {!currentLLMSupportsImageOutput ? (
+                  {!isImageGenerationAvailable && (
                     <TooltipContent side="top" align="center">
                       <p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
-                        To use Image Generation, select GPT-4o or another
-                        image compatible model as the default model for
-                        this Assistant.
+                        Image Generation requires an OpenAI or Azure Dalle
+                        configuration.
                       </p>
                     </TooltipContent>
-                  ) : (
-                    !isImageGenerationAvailable && (
-                      <TooltipContent side="top" align="center">
-                        <p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
-                          Image Generation requires an OpenAI or Azure
-                          Dalle configuration.
-                        </p>
-                      </TooltipContent>
-                    )
                   )}
                 </Tooltip>
               </TooltipProvider>
@@ -142,6 +142,8 @@ export function CustomLLMProviderUpdateForm({
         },
         body: JSON.stringify({
           ...values,
+          // For custom llm providers, all model names are displayed
+          display_model_names: values.model_names,
           custom_config: customConfigProcessing(values.custom_config_list),
         }),
       });
@@ -50,6 +50,7 @@ import {
   useContext,
   useEffect,
   useLayoutEffect,
+  useMemo,
   useRef,
   useState,
 } from "react";
@@ -243,9 +244,9 @@ export function ChatPage({
   };

   const llmOverrideManager = useLlmOverride(
-    modelVersionFromSearchParams || (user?.preferences.default_model ?? null),
-    selectedChatSession,
-    defaultTemperature
+    llmProviders,
+    user?.preferences.default_model,
+    selectedChatSession
   );

   const [alternativeAssistant, setAlternativeAssistant] =
@@ -269,20 +270,21 @@ export function ChatPage({

   // always set the model override for the chat session, when an assistant, llm provider, or user preference exists
   useEffect(() => {
+    if (noAssistants) return;
     const personaDefault = getLLMProviderOverrideForPersona(
       liveAssistant,
       llmProviders
     );

     if (personaDefault) {
-      llmOverrideManager.setLlmOverride(personaDefault);
+      llmOverrideManager.updateLLMOverride(personaDefault);
     } else if (user?.preferences.default_model) {
-      llmOverrideManager.setLlmOverride(
+      llmOverrideManager.updateLLMOverride(
         destructureValue(user?.preferences.default_model)
       );
     }
     // eslint-disable-next-line react-hooks/exhaustive-deps
-  }, [liveAssistant, llmProviders, user?.preferences.default_model]);
+  }, [liveAssistant, user?.preferences.default_model]);

   const stopGenerating = () => {
     const currentSession = currentSessionId();
@@ -370,7 +372,7 @@ export function ChatPage({

     // reset LLM overrides (based on chat session!)
     llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
-    llmOverrideManager.setTemperature(null);
+    llmOverrideManager.updateTemperature(null);

     // remove uploaded files
     setCurrentMessageFiles([]);
@@ -1514,7 +1516,7 @@ export function ChatPage({
         setPopup({
           type: "error",
           message:
-            "The current Assistant does not support image input. Please select an assistant with Vision support.",
+            "The current model does not support image input. Please select a model with Vision support.",
         });
         return;
       }
@@ -1722,6 +1724,14 @@ export function ChatPage({
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [messageHistory]);

+  const imageFileInMessageHistory = useMemo(() => {
+    return messageHistory
+      .filter((message) => message.type === "user")
+      .some((message) =>
+        message.files.some((file) => file.type === ChatFileType.IMAGE)
+      );
+  }, [messageHistory]);
+
   const currentVisibleRange = visibleRange.get(currentSessionId()) || {
     start: 0,
     end: 0,
@@ -2435,6 +2445,9 @@ export function ChatPage({
                       </div>
                     )}
                     <ChatInputBar
+                      sessionContainsImageFiles={
+                        imageFileInMessageHistory
+                      }
                       showConfigureAPIKey={() =>
                         setShowApiKeyModal(true)
                       }
@@ -125,9 +125,9 @@ export default function RegenerateOption({
   onHoverChange: (isHovered: boolean) => void;
   onDropdownVisibleChange: (isVisible: boolean) => void;
 }) {
-  const llmOverrideManager = useLlmOverride();
-
   const { llmProviders } = useChatContext();
+  const llmOverrideManager = useLlmOverride(llmProviders);

   const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);

   const llmOptionsByProvider: {
@@ -68,6 +68,7 @@ export function ChatInputBar({
   alternativeAssistant,
   chatSessionId,
   inputPrompts,
+  sessionContainsImageFiles,
 }: {
   showConfigureAPIKey: () => void;
   openModelSettings: () => void;
@@ -90,6 +91,7 @@ export function ChatInputBar({
   handleFileUpload: (files: File[]) => void;
   textAreaRef: React.RefObject<HTMLTextAreaElement>;
   chatSessionId?: string;
+  sessionContainsImageFiles: boolean;
 }) {
   useEffect(() => {
     const textarea = textAreaRef.current;
@@ -558,7 +560,6 @@ export function ChatInputBar({
             tab
             content={(close, ref) => (
               <LlmTab
-                currentAssistant={alternativeAssistant || selectedAssistant}
                 openModelSettings={openModelSettings}
                 currentLlm={
                   llmOverrideManager.llmOverride.modelName ||
@@ -572,6 +573,7 @@ export function ChatInputBar({
                 ref={ref}
                 llmOverrideManager={llmOverrideManager}
                 chatSessionId={chatSessionId}
+                imageFilesPresent={sessionContainsImageFiles}
               />
             )}
             position="top"
@@ -144,3 +144,6 @@ export interface StreamingError {
   error: string;
   stack_trace: string;
 }
+
+export const isAnthropic = (provider: string, modelName: string) =>
+  provider === "anthropic" || modelName.toLowerCase().includes("claude");

@@ -8,7 +8,6 @@ import { destructureValue } from "@/lib/llm/utils";
 import { updateModelOverrideForChatSession } from "../../lib";
 import { GearIcon } from "@/components/icons/icons";
 import { LlmList } from "@/components/llm/LLMList";
-import { checkPersonaRequiresImageGeneration } from "@/app/admin/assistants/lib";

 interface LlmTabProps {
   llmOverrideManager: LlmOverrideManager;
@@ -16,7 +15,7 @@ interface LlmTabProps {
   openModelSettings: () => void;
   chatSessionId?: string;
   close: () => void;
-  currentAssistant: Persona;
+  imageFilesPresent: boolean;
 }

 export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
@@ -27,15 +26,13 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
       currentLlm,
       close,
       openModelSettings,
-      currentAssistant,
+      imageFilesPresent,
     },
     ref
   ) => {
-    const requiresImageGeneration =
-      checkPersonaRequiresImageGeneration(currentAssistant);
-
     const { llmProviders } = useChatContext();
-    const { setLlmOverride, temperature, setTemperature } = llmOverrideManager;
+    const { updateLLMOverride, temperature, updateTemperature } =
+      llmOverrideManager;
     const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
     const [localTemperature, setLocalTemperature] = useState<number>(
       temperature || 0
@@ -43,11 +40,11 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
     const debouncedSetTemperature = useCallback(
       (value: number) => {
         const debouncedFunction = debounce((value: number) => {
-          setTemperature(value);
+          updateTemperature(value);
         }, 300);
         return debouncedFunction(value);
       },
-      [setTemperature]
+      [updateTemperature]
     );

     const handleTemperatureChange = (value: number) => {
@@ -69,14 +66,14 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
           </button>
         </div>
         <LlmList
-          requiresImageGeneration={requiresImageGeneration}
           llmProviders={llmProviders}
           currentLlm={currentLlm}
+          imageFilesPresent={imageFilesPresent}
           onSelect={(value: string | null) => {
             if (value == null) {
               return;
             }
-            setLlmOverride(destructureValue(value));
+            updateLLMOverride(destructureValue(value));
             if (chatSessionId) {
               updateModelOverrideForChatSession(chatSessionId, value as string);
             }

@@ -6,6 +6,14 @@ import {
   LLMProviderDescriptor,
 } from "@/app/admin/configuration/llm/interfaces";

+import {
+  Tooltip,
+  TooltipContent,
+  TooltipProvider,
+  TooltipTrigger,
+} from "@/components/ui/tooltip";
+import { FiAlertTriangle } from "react-icons/fi";
+
 interface LlmListProps {
   llmProviders: LLMProviderDescriptor[];
   currentLlm: string;
@@ -13,7 +21,7 @@ interface LlmListProps {
   userDefault?: string | null;
   scrollable?: boolean;
   hideProviderIcon?: boolean;
   requiresImageGeneration?: boolean;
+  imageFilesPresent?: boolean;
 }

 export const LlmList: React.FC<LlmListProps> = ({
@@ -22,7 +30,7 @@ export const LlmList: React.FC<LlmListProps> = ({
   onSelect,
   userDefault,
   scrollable,
   requiresImageGeneration,
+  imageFilesPresent,
 }) => {
   const llmOptionsByProvider: {
     [provider: string]: {
@@ -31,6 +39,7 @@ export const LlmList: React.FC<LlmListProps> = ({
       icon: React.FC<{ size?: number; className?: string }>;
     }[];
   } = {};

+  const uniqueModelNames = new Set<string>();

   llmProviders.forEach((llmProvider) => {
@@ -62,7 +71,9 @@ export const LlmList: React.FC<LlmListProps> = ({

   return (
     <div
-      className={`${scrollable ? "max-h-[200px] include-scrollbar" : "max-h-[300px]"} bg-background-175 flex flex-col gap-y-1 overflow-y-scroll`}
+      className={`${
+        scrollable ? "max-h-[200px] include-scrollbar" : "max-h-[300px]"
+      } bg-background-175 flex flex-col gap-y-1 overflow-y-scroll`}
     >
       {userDefault && (
         <button
@@ -79,25 +90,36 @@ export const LlmList: React.FC<LlmListProps> = ({
         </button>
       )}

-      {llmOptions.map(({ name, icon, value }, index) => {
-        if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
-          return (
-            <button
-              type="button"
-              key={index}
-              className={`w-full py-1.5 flex gap-x-2 px-2 text-sm ${
-                currentLlm == name
-                  ? "bg-background-200"
-                  : "bg-background hover:bg-background-100"
-              } text-left rounded`}
-              onClick={() => onSelect(value)}
-            >
-              {icon({ size: 16 })}
-              {getDisplayNameForModel(name)}
-            </button>
-          );
-        }
-      })}
+      {llmOptions.map(({ name, icon, value }, index) => (
+        <button
+          type="button"
+          key={index}
+          className={`w-full py-1.5 flex items-center justify-start gap-x-2 px-2 text-sm ${
+            currentLlm == name
+              ? "bg-background-200"
+              : "bg-background hover:bg-background-100"
+          } text-left rounded`}
+          onClick={() => onSelect(value)}
+        >
+          {icon({ size: 16 })}
+          {getDisplayNameForModel(name)}
+          {imageFilesPresent && !checkLLMSupportsImageInput(name) && (
+            <TooltipProvider>
+              <Tooltip delayDuration={0}>
+                <TooltipTrigger className="my-auto flex items-center ml-auto">
+                  <FiAlertTriangle className="text-alert" size={16} />
+                </TooltipTrigger>
+                <TooltipContent>
+                  <p className="text-xs">
+                    This LLM is not vision-capable and cannot process image
+                    files present in your chat session.
+                  </p>
+                </TooltipContent>
+              </Tooltip>
+            </TooltipProvider>
+          )}
+        </button>
+      ))}
     </div>
   );
 };

@@ -10,12 +10,13 @@ import { errorHandlingFetcher } from "./fetcher";
 import { useContext, useEffect, useState } from "react";
 import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
 import { SourceMetadata } from "./search/interfaces";
-import { destructureValue } from "./llm/utils";
-import { ChatSession } from "@/app/chat/interfaces";
+import { destructureValue, structureValue } from "./llm/utils";
+import { ChatSession, isAnthropic } from "@/app/chat/interfaces";
 import { UsersResponse } from "./users/interfaces";
 import { Credential } from "./connectors/credentials";
 import { SettingsContext } from "@/components/settings/SettingsProvider";
 import { PersonaCategory } from "@/app/admin/assistants/interfaces";
+import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";

 const CREDENTIAL_URL = "/api/manage/admin/credential";

@@ -71,7 +72,9 @@ export const useConnectorCredentialIndexingStatus = (
   getEditable = false
 ) => {
   const { mutate } = useSWRConfig();
-  const url = `${INDEXING_STATUS_URL}${getEditable ? "?get_editable=true" : ""}`;
+  const url = `${INDEXING_STATUS_URL}${
+    getEditable ? "?get_editable=true" : ""
+  }`;
   const swrResponse = useSWR<ConnectorIndexingStatus<any, any>[]>(
     url,
     errorHandlingFetcher,

@@ -153,76 +156,102 @@ export interface LlmOverride {

 export interface LlmOverrideManager {
   llmOverride: LlmOverride;
-  setLlmOverride: React.Dispatch<React.SetStateAction<LlmOverride>>;
+  updateLLMOverride: (newOverride: LlmOverride) => void;
   globalDefault: LlmOverride;
   setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
   temperature: number | null;
-  setTemperature: React.Dispatch<React.SetStateAction<number | null>>;
+  updateTemperature: (temperature: number | null) => void;
   updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
 }
 export function useLlmOverride(
+  llmProviders: LLMProviderDescriptor[],
   globalModel?: string | null,
   currentChatSession?: ChatSession,
   defaultTemperature?: number
 ): LlmOverrideManager {
+  const getValidLlmOverride = (
+    overrideModel: string | null | undefined
+  ): LlmOverride => {
+    if (overrideModel) {
+      const model = destructureValue(overrideModel);
+      const provider = llmProviders.find(
+        (p) =>
+          p.model_names.includes(model.modelName) &&
+          p.provider === model.provider
+      );
+      if (provider) {
+        return { ...model, name: provider.name };
+      }
+    }
+    return { name: "", provider: "", modelName: "" };
+  };
+
   const [globalDefault, setGlobalDefault] = useState<LlmOverride>(
-    globalModel != null
-      ? destructureValue(globalModel)
-      : {
-          name: "",
-          provider: "",
-          modelName: "",
-        }
+    getValidLlmOverride(globalModel)
   );
+
+  const updateLLMOverride = (newOverride: LlmOverride) => {
+    setLlmOverride(
+      getValidLlmOverride(
+        structureValue(
+          newOverride.name,
+          newOverride.provider,
+          newOverride.modelName
+        )
+      )
+    );
+  };

   const [llmOverride, setLlmOverride] = useState<LlmOverride>(
     currentChatSession && currentChatSession.current_alternate_model
-      ? destructureValue(currentChatSession.current_alternate_model)
-      : {
-          name: "",
-          provider: "",
-          modelName: "",
-        }
+      ? getValidLlmOverride(currentChatSession.current_alternate_model)
+      : { name: "", provider: "", modelName: "" }
   );

   const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
     setLlmOverride(
       chatSession && chatSession.current_alternate_model
-        ? destructureValue(chatSession.current_alternate_model)
+        ? getValidLlmOverride(chatSession.current_alternate_model)
        : globalDefault
    );
  };

   const [temperature, setTemperature] = useState<number | null>(
-    defaultTemperature != undefined ? defaultTemperature : 0
+    defaultTemperature !== undefined ? defaultTemperature : 0
   );

   useEffect(() => {
-    setGlobalDefault(
-      globalModel != null
-        ? destructureValue(globalModel)
-        : {
-            name: "",
-            provider: "",
-            modelName: "",
-          }
-    );
-  }, [globalModel]);
+    setGlobalDefault(getValidLlmOverride(globalModel));
+  }, [globalModel, llmProviders]);

+  useEffect(() => {
+    setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0);
+  }, [defaultTemperature]);
+
+  useEffect(() => {
+    if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
+      setTemperature((prevTemp) => Math.min(prevTemp ?? 0, 1.0));
+    }
+  }, [llmOverride]);
+
+  const updateTemperature = (temperature: number | null) => {
+    if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
+      setTemperature(Math.min(temperature ?? 0, 1.0));
+    } else {
+      setTemperature(temperature);
+    }
+  };
+
   return {
     updateModelOverrideForChatSession,
     llmOverride,
-    setLlmOverride,
+    updateLLMOverride,
     globalDefault,
     setGlobalDefault,
     temperature,
-    setTemperature,
+    updateTemperature,
   };
 }

/*
EE Only APIs
*/