mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-17 15:36:43 +00:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9862b0ef59 | ||
|
|
8a7aeb2c59 | ||
|
|
648dcd1e47 | ||
|
|
f73796928c | ||
|
|
91101e8f2c | ||
|
|
44bb3ded44 | ||
|
|
493e3f23b8 | ||
|
|
031c1118bd | ||
|
|
b8b7702f28 | ||
|
|
ebb67aede9 | ||
|
|
340cd520eb | ||
|
|
b626ad232c | ||
|
|
f1ee9c12c0 | ||
|
|
378cbedaa1 | ||
|
|
f87e03b194 | ||
|
|
873636a095 | ||
|
|
efb194e067 | ||
|
|
3f7dfa7813 |
@@ -4,8 +4,6 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_tool_call_failure_messages
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
@@ -635,7 +633,6 @@ def run_llm_loop(
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -1020,20 +1017,16 @@ def run_llm_loop(
|
||||
persisted_memory_id: int | None = None
|
||||
if user_memory_context and user_memory_context.user_id:
|
||||
if tool_response.rich_response.index_to_replace is not None:
|
||||
memory = update_memory_at_index(
|
||||
persisted_memory_id = update_memory_at_index(
|
||||
user_id=user_memory_context.user_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
new_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id if memory else None
|
||||
else:
|
||||
memory = add_memory(
|
||||
persisted_memory_id = add_memory(
|
||||
user_id=user_memory_context.user_id,
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id
|
||||
operation: Literal["add", "update"] = (
|
||||
"update"
|
||||
if tool_response.rich_response.index_to_replace is not None
|
||||
|
||||
@@ -826,6 +826,12 @@ def translate_history_to_llm_format(
|
||||
base64_data = img_file.to_base64()
|
||||
image_url = f"data:{image_type};base64,{base64_data}"
|
||||
|
||||
content_parts.append(
|
||||
TextContentPart(
|
||||
type="text",
|
||||
text=f"[attached image — file_id: {img_file.file_id}]",
|
||||
)
|
||||
)
|
||||
image_part = ImageContentPart(
|
||||
type="image_url",
|
||||
image_url=ImageUrlDetail(
|
||||
|
||||
@@ -67,7 +67,6 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
@@ -1006,93 +1005,86 @@ def _run_models(
|
||||
model_llm = setup.llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
|
||||
# Do NOT write to the outer db_session (or any shared DB state) from here;
|
||||
# all DB writes in this thread must go through thread_db_session.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
db_session=thread_db_session,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
# Each function opens short-lived DB sessions on demand.
|
||||
# Do NOT pass a long-lived session here — it would hold a
|
||||
# connection for the entire LLM loop (minutes), and cloud
|
||||
# infrastructure may drop idle connections.
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool for tool_list in thread_tool_dict.values() for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
model_tools = [
|
||||
tool
|
||||
for tool_list in thread_tool_dict.values()
|
||||
for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError(
|
||||
"Deep research is not supported for projects"
|
||||
)
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
@@ -346,6 +347,25 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _safe_close_session(session: Session) -> None:
|
||||
"""Close a session, catching connection-closed errors during cleanup.
|
||||
|
||||
Long-running operations (e.g. multi-model LLM loops) can hold a session
|
||||
open for minutes. If the underlying connection is dropped by cloud
|
||||
infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
|
||||
timeouts, etc.), the implicit rollback in Session.close() raises
|
||||
OperationalError or InterfaceError. Since the work is already complete,
|
||||
we log and move on — SQLAlchemy internally invalidates the connection
|
||||
for pool recycling.
|
||||
"""
|
||||
try:
|
||||
session.close()
|
||||
except DBAPIError:
|
||||
logger.warning(
|
||||
"DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool."
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
@@ -358,8 +378,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
|
||||
# no need to use the schema translation map for self-hosted + default schema
|
||||
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
|
||||
with Session(bind=engine, expire_on_commit=False) as session:
|
||||
session = Session(bind=engine, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
return
|
||||
|
||||
# Create connection with schema translation to handle querying the right schema
|
||||
@@ -367,8 +390,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
with engine.connect().execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
) as connection:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
session = Session(bind=connection, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -83,47 +84,51 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
def add_memory(
|
||||
user_id: UUID,
|
||||
memory_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory:
|
||||
db_session: Session | None = None,
|
||||
) -> int:
|
||||
"""Insert a new Memory row for the given user.
|
||||
|
||||
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
|
||||
one (lowest id) is deleted before inserting the new one.
|
||||
|
||||
Returns the id of the newly created Memory row.
|
||||
"""
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
|
||||
def update_memory_at_index(
|
||||
user_id: UUID,
|
||||
index: int,
|
||||
new_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory | None:
|
||||
db_session: Session | None = None,
|
||||
) -> int | None:
|
||||
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
|
||||
|
||||
Returns the updated Memory row, or None if the index is out of range.
|
||||
Returns the id of the updated Memory row, or None if the index is out of range.
|
||||
"""
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
@@ -7,8 +7,6 @@ import time
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
@@ -22,6 +20,7 @@ from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
|
||||
@@ -184,6 +183,14 @@ def generate_final_report(
|
||||
return has_reasoned
|
||||
|
||||
|
||||
def _get_research_agent_tool_id() -> int:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def run_deep_research_llm_loop(
|
||||
emitter: Emitter,
|
||||
@@ -193,7 +200,6 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt: str | None, # noqa: ARG001
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -717,6 +723,7 @@ def run_deep_research_llm_loop(
|
||||
simple_chat_history.append(assistant_with_tools)
|
||||
|
||||
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
|
||||
research_agent_tool_id = _get_research_agent_tool_id()
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
@@ -737,10 +744,7 @@ def run_deep_research_llm_loop(
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id,
|
||||
tool_id=research_agent_tool_id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
|
||||
@@ -1516,6 +1516,10 @@
|
||||
"display_name": "Claude Opus 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-7": {
|
||||
"display_name": "Claude Opus 4.7",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-5-20251101": {
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
|
||||
@@ -46,6 +46,15 @@ ANTHROPIC_REASONING_EFFORT_BUDGET: dict[ReasoningEffort, int] = {
|
||||
ReasoningEffort.HIGH: 4096,
|
||||
}
|
||||
|
||||
# Newer Anthropic models (Claude Opus 4.7+) use adaptive thinking with
|
||||
# output_config.effort instead of thinking.type.enabled + budget_tokens.
|
||||
ANTHROPIC_ADAPTIVE_REASONING_EFFORT: dict[ReasoningEffort, str] = {
|
||||
ReasoningEffort.AUTO: "medium",
|
||||
ReasoningEffort.LOW: "low",
|
||||
ReasoningEffort.MEDIUM: "medium",
|
||||
ReasoningEffort.HIGH: "high",
|
||||
}
|
||||
|
||||
|
||||
# Content part structures for multimodal messages
|
||||
# The classes in this mirror the OpenAI Chat Completions message types and work well with routers like LiteLLM
|
||||
|
||||
@@ -23,6 +23,7 @@ from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.model_response import ModelResponse
|
||||
from onyx.llm.model_response import ModelResponseStream
|
||||
from onyx.llm.model_response import Usage
|
||||
from onyx.llm.models import ANTHROPIC_ADAPTIVE_REASONING_EFFORT
|
||||
from onyx.llm.models import ANTHROPIC_REASONING_EFFORT_BUDGET
|
||||
from onyx.llm.models import OPENAI_REASONING_EFFORT
|
||||
from onyx.llm.request_context import get_llm_mock_response
|
||||
@@ -67,8 +68,13 @@ STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
|
||||
_VERTEX_ANTHROPIC_MODELS_REJECTING_OUTPUT_CONFIG = (
|
||||
"claude-opus-4-5",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-7",
|
||||
)
|
||||
|
||||
# Anthropic models that require the adaptive thinking API (thinking.type.adaptive
|
||||
# + output_config.effort) instead of the legacy thinking.type.enabled + budget_tokens.
|
||||
_ANTHROPIC_ADAPTIVE_THINKING_MODELS = ("claude-opus-4-7",)
|
||||
|
||||
|
||||
class LLMTimeoutError(Exception):
|
||||
"""
|
||||
@@ -230,6 +236,14 @@ def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _anthropic_uses_adaptive_thinking(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
adaptive_model in normalized_model_name
|
||||
for adaptive_model in _ANTHROPIC_ADAPTIVE_THINKING_MODELS
|
||||
)
|
||||
|
||||
|
||||
class LitellmLLM(LLM):
|
||||
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
|
||||
See https://python.langchain.com/docs/integrations/chat/litellm"""
|
||||
@@ -509,10 +523,6 @@ class LitellmLLM(LLM):
|
||||
}
|
||||
|
||||
elif is_claude_model:
|
||||
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
|
||||
reasoning_effort
|
||||
)
|
||||
|
||||
# Anthropic requires every assistant message with tool_use
|
||||
# blocks to start with a thinking block that carries a
|
||||
# cryptographic signature. We don't preserve those blocks
|
||||
@@ -520,24 +530,35 @@ class LitellmLLM(LLM):
|
||||
# contains tool-calling assistant messages. LiteLLM's
|
||||
# modify_params workaround doesn't cover all providers
|
||||
# (notably Bedrock).
|
||||
can_enable_thinking = (
|
||||
budget_tokens is not None
|
||||
and not _prompt_contains_tool_call_history(prompt)
|
||||
)
|
||||
has_tool_call_history = _prompt_contains_tool_call_history(prompt)
|
||||
|
||||
if can_enable_thinking:
|
||||
assert budget_tokens is not None # mypy
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
|
||||
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
|
||||
# call as compared to reducing the budget for reasoning.
|
||||
max_tokens = max(budget_tokens + 1, max_tokens)
|
||||
optional_kwargs["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": budget_tokens,
|
||||
}
|
||||
if _anthropic_uses_adaptive_thinking(self.config.model_name):
|
||||
# Newer Anthropic models (Claude Opus 4.7+) reject
|
||||
# thinking.type.enabled — they require the adaptive
|
||||
# thinking config with output_config.effort.
|
||||
if not has_tool_call_history:
|
||||
optional_kwargs["thinking"] = {"type": "adaptive"}
|
||||
optional_kwargs["output_config"] = {
|
||||
"effort": ANTHROPIC_ADAPTIVE_REASONING_EFFORT[
|
||||
reasoning_effort
|
||||
],
|
||||
}
|
||||
else:
|
||||
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
|
||||
reasoning_effort
|
||||
)
|
||||
if budget_tokens is not None and not has_tool_call_history:
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
|
||||
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
|
||||
# call as compared to reducing the budget for reasoning.
|
||||
max_tokens = max(budget_tokens + 1, max_tokens)
|
||||
optional_kwargs["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": budget_tokens,
|
||||
}
|
||||
|
||||
# LiteLLM just does some mapping like this anyway but is incomplete for Anthropic
|
||||
optional_kwargs.pop("reasoning_effort", None)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"version": "1.1",
|
||||
"updated_at": "2026-03-05T00:00:00Z",
|
||||
"version": "1.2",
|
||||
"updated_at": "2026-04-16T00:00:00Z",
|
||||
"providers": {
|
||||
"openai": {
|
||||
"default_model": { "name": "gpt-5.4" },
|
||||
@@ -10,8 +10,12 @@
|
||||
]
|
||||
},
|
||||
"anthropic": {
|
||||
"default_model": "claude-opus-4-6",
|
||||
"default_model": "claude-opus-4-7",
|
||||
"additional_visible_models": [
|
||||
{
|
||||
"name": "claude-opus-4-7",
|
||||
"display_name": "Claude Opus 4.7"
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-6",
|
||||
"display_name": "Claude Opus 4.6"
|
||||
|
||||
@@ -65,8 +65,9 @@ IMPORTANT: each call to this tool is independent. Variables from previous calls
|
||||
GENERATE_IMAGE_GUIDANCE = """
|
||||
## generate_image
|
||||
NEVER use generate_image unless the user specifically requests an image.
|
||||
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
|
||||
the `file_id` values returned by earlier `generate_image` tool results.
|
||||
To edit, restyle, or vary an existing image, pass its file_id in `reference_image_file_ids`. \
|
||||
File IDs come from `[attached image — file_id: <id>]` tags on user-attached images or from prior `generate_image` tool results — never invent one. \
|
||||
Leave `reference_image_file_ids` unset for a fresh generation.
|
||||
""".lstrip()
|
||||
|
||||
MEMORY_GUIDANCE = """
|
||||
|
||||
@@ -618,6 +618,7 @@ done
|
||||
"app.kubernetes.io/managed-by": "onyx",
|
||||
"onyx.app/sandbox-id": sandbox_id,
|
||||
"onyx.app/tenant-id": tenant_id,
|
||||
"admission.datadoghq.com/enabled": "false",
|
||||
},
|
||||
),
|
||||
spec=pod_spec,
|
||||
|
||||
@@ -11,6 +11,9 @@ from onyx.db.notification import dismiss_notification
|
||||
from onyx.db.notification import get_notification_by_id
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.server.features.build.utils import ensure_build_mode_intro_notification
|
||||
from onyx.server.features.notifications.utils import (
|
||||
ensure_permissions_migration_notification,
|
||||
)
|
||||
from onyx.server.features.release_notes.utils import (
|
||||
ensure_release_notes_fresh_and_notify,
|
||||
)
|
||||
@@ -49,6 +52,13 @@ def get_notifications_api(
|
||||
except Exception:
|
||||
logger.exception("Failed to check for release notes in notifications endpoint")
|
||||
|
||||
try:
|
||||
ensure_permissions_migration_notification(user, db_session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to create permissions_migration_v1 announcement in notifications endpoint"
|
||||
)
|
||||
|
||||
notifications = [
|
||||
NotificationModel.from_model(notif)
|
||||
for notif in get_notifications(user, db_session, include_dismissed=True)
|
||||
|
||||
21
backend/onyx/server/features/notifications/utils.py
Normal file
21
backend/onyx/server/features/notifications/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import create_notification
|
||||
|
||||
|
||||
def ensure_permissions_migration_notification(user: User, db_session: Session) -> None:
|
||||
# Feature id "permissions_migration_v1" must not change after shipping —
|
||||
# it is the dedup key on (user_id, notif_type, additional_data).
|
||||
create_notification(
|
||||
user_id=user.id,
|
||||
notif_type=NotificationType.FEATURE_ANNOUNCEMENT,
|
||||
db_session=db_session,
|
||||
title="Permissions are changing in Onyx",
|
||||
description="Roles are moving to group-based permissions. Click for details.",
|
||||
additional_data={
|
||||
"feature": "permissions_migration_v1",
|
||||
"link": "https://docs.onyx.app/admins/permissions/whats_changing",
|
||||
},
|
||||
)
|
||||
@@ -183,6 +183,9 @@ def generate_ollama_display_name(model_name: str) -> str:
|
||||
"qwen2.5:7b" → "Qwen 2.5 7B"
|
||||
"mistral:latest" → "Mistral"
|
||||
"deepseek-r1:14b" → "DeepSeek R1 14B"
|
||||
"gemma4:e4b" → "Gemma 4 E4B"
|
||||
"deepseek-v3.1:671b-cloud" → "DeepSeek V3.1 671B Cloud"
|
||||
"qwen3-vl:235b-instruct-cloud" → "Qwen 3-vl 235B Instruct Cloud"
|
||||
"""
|
||||
# Split into base name and tag
|
||||
if ":" in model_name:
|
||||
@@ -209,13 +212,24 @@ def generate_ollama_display_name(model_name: str) -> str:
|
||||
# Default: Title case with dashes converted to spaces
|
||||
display_name = base.replace("-", " ").title()
|
||||
|
||||
# Process tag to extract size info (skip "latest")
|
||||
# Process tag (skip "latest")
|
||||
if tag and tag.lower() != "latest":
|
||||
# Extract size like "7b", "70b", "14b"
|
||||
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])", tag)
|
||||
# Check for size prefix like "7b", "70b", optionally followed by modifiers
|
||||
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])(-.+)?$", tag)
|
||||
if size_match:
|
||||
size = size_match.group(1).upper()
|
||||
display_name = f"{display_name} {size}"
|
||||
remainder = size_match.group(2)
|
||||
if remainder:
|
||||
# Format modifiers like "-cloud", "-instruct-cloud"
|
||||
modifiers = " ".join(
|
||||
p.title() for p in remainder.strip("-").split("-") if p
|
||||
)
|
||||
display_name = f"{display_name} {size} {modifiers}"
|
||||
else:
|
||||
display_name = f"{display_name} {size}"
|
||||
else:
|
||||
# Non-size tags like "e4b", "q4_0", "fp16", "cloud"
|
||||
display_name = f"{display_name} {tag.upper()}"
|
||||
|
||||
return display_name
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import json
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
@@ -113,28 +114,47 @@ async def transcribe_audio(
|
||||
) from exc
|
||||
|
||||
|
||||
def _extract_provider_error(exc: Exception) -> str:
|
||||
"""Extract a human-readable message from a provider exception.
|
||||
|
||||
Provider errors often embed JSON from upstream APIs (e.g. ElevenLabs).
|
||||
This tries to parse a readable ``message`` field out of common JSON
|
||||
error shapes; falls back to ``str(exc)`` if nothing better is found.
|
||||
"""
|
||||
raw = str(exc)
|
||||
try:
|
||||
# Many providers embed JSON after a prefix like "ElevenLabs TTS failed: {...}"
|
||||
json_start = raw.find("{")
|
||||
if json_start == -1:
|
||||
return raw
|
||||
parsed = json.loads(raw[json_start:])
|
||||
# Shape: {"detail": {"message": "..."}} (ElevenLabs)
|
||||
detail = parsed.get("detail", parsed)
|
||||
if isinstance(detail, dict):
|
||||
return detail.get("message") or detail.get("error") or raw
|
||||
if isinstance(detail, str):
|
||||
return detail
|
||||
except (json.JSONDecodeError, AttributeError, TypeError):
|
||||
pass
|
||||
return raw
|
||||
|
||||
|
||||
class SynthesizeRequest(BaseModel):
|
||||
text: str = Field(..., min_length=1)
|
||||
voice: str | None = None
|
||||
speed: float | None = Field(default=None, ge=0.5, le=2.0)
|
||||
|
||||
|
||||
@router.post("/synthesize")
|
||||
async def synthesize_speech(
|
||||
text: str | None = Query(
|
||||
default=None, description="Text to synthesize", max_length=4096
|
||||
),
|
||||
voice: str | None = Query(default=None, description="Voice ID to use"),
|
||||
speed: float | None = Query(
|
||||
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
|
||||
),
|
||||
body: SynthesizeRequest,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Synthesize text to speech using the default TTS provider.
|
||||
|
||||
Accepts parameters via query string for streaming compatibility.
|
||||
"""
|
||||
logger.info(
|
||||
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
|
||||
)
|
||||
|
||||
if not text:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
|
||||
"""Synthesize text to speech using the default TTS provider."""
|
||||
text = body.text
|
||||
voice = body.voice
|
||||
speed = body.speed
|
||||
logger.info(f"TTS request: text length={len(text)}, voice={voice}, speed={speed}")
|
||||
|
||||
# Use short-lived session to fetch provider config, then release connection
|
||||
# before starting the long-running streaming response
|
||||
@@ -177,31 +197,36 @@ async def synthesize_speech(
|
||||
logger.error(f"Failed to get voice provider: {exc}")
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
# Session is now closed - streaming response won't hold DB connection
|
||||
# Pull the first chunk before returning the StreamingResponse. If the
|
||||
# provider rejects the request (e.g. text too long), the error surfaces
|
||||
# as a proper HTTP error instead of a broken audio stream.
|
||||
stream_iter = provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
)
|
||||
try:
|
||||
first_chunk = await stream_iter.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "TTS provider returned no audio")
|
||||
except Exception as exc:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, _extract_provider_error(exc)
|
||||
) from exc
|
||||
|
||||
async def audio_stream() -> AsyncIterator[bytes]:
|
||||
try:
|
||||
chunk_count = 0
|
||||
async for chunk in provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
):
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
except NotImplementedError as exc:
|
||||
logger.error(f"TTS not implemented: {exc}")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Synthesis failed: {exc}")
|
||||
raise
|
||||
yield first_chunk
|
||||
chunk_count = 1
|
||||
async for chunk in stream_iter:
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
|
||||
return StreamingResponse(
|
||||
audio_stream(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Content-Disposition": "inline; filename=speech.mp3",
|
||||
# Allow streaming by not setting content-length
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -208,12 +208,6 @@ class PythonToolOverrideKwargs(BaseModel):
|
||||
chat_files: list[ChatFile] = []
|
||||
|
||||
|
||||
class ImageGenerationToolOverrideKwargs(BaseModel):
|
||||
"""Override kwargs for image generation tool calls."""
|
||||
|
||||
recent_generated_image_file_ids: list[str] = []
|
||||
|
||||
|
||||
class SearchToolRunContext(BaseModel):
|
||||
emitter: Emitter
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.mcp import get_all_mcp_tools_for_server
|
||||
@@ -113,10 +114,10 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
|
||||
def construct_tools(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
db_session: Session | None = None,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
@@ -131,6 +132,33 @@ def construct_tools(
|
||||
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
|
||||
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
|
||||
to avoid lazy SQL queries after the session may have been flushed."""
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
return _construct_tools_impl(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
custom_tool_config=custom_tool_config,
|
||||
file_reader_tool_config=file_reader_tool_config,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
search_usage_forcing_setting=search_usage_forcing_setting,
|
||||
)
|
||||
|
||||
|
||||
def _construct_tools_impl(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
search_usage_forcing_setting: SearchToolUsage = SearchToolUsage.AUTO,
|
||||
) -> dict[int, list[Tool]]:
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Log which tools are attached to the persona for debugging
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeart
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ImageGenerationToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import ToolExecutionException
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -48,7 +47,7 @@ PROMPT_FIELD = "prompt"
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD = "reference_image_file_ids"
|
||||
|
||||
|
||||
class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
class ImageGenerationTool(Tool[None]):
|
||||
NAME = "generate_image"
|
||||
DESCRIPTION = "Generate an image based on a prompt. Do not use unless the user specifically requests an image."
|
||||
DISPLAY_NAME = "Image Generation"
|
||||
@@ -142,8 +141,11 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD: {
|
||||
"type": "array",
|
||||
"description": (
|
||||
"Optional image file IDs to use as reference context for edits/variations. "
|
||||
"Use the file_id values returned by previous generate_image calls."
|
||||
"Optional file_ids of existing images to edit or use as reference;"
|
||||
" the first is the primary edit source."
|
||||
" Get file_ids from `[attached image — file_id: <id>]` tags on"
|
||||
" user-attached images or from prior generate_image tool responses."
|
||||
" Omit for a fresh, unrelated generation."
|
||||
),
|
||||
"items": {
|
||||
"type": "string",
|
||||
@@ -254,41 +256,31 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
def _resolve_reference_image_file_ids(
|
||||
self,
|
||||
llm_kwargs: dict[str, Any],
|
||||
override_kwargs: ImageGenerationToolOverrideKwargs | None,
|
||||
) -> list[str]:
|
||||
raw_reference_ids = llm_kwargs.get(REFERENCE_IMAGE_FILE_IDS_FIELD)
|
||||
if raw_reference_ids is not None:
|
||||
if not isinstance(raw_reference_ids, list) or not all(
|
||||
isinstance(file_id, str) for file_id in raw_reference_ids
|
||||
):
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, got {type(raw_reference_ids)}"
|
||||
),
|
||||
llm_facing_message=(
|
||||
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
|
||||
),
|
||||
)
|
||||
reference_image_file_ids = [
|
||||
file_id.strip() for file_id in raw_reference_ids if file_id.strip()
|
||||
]
|
||||
elif (
|
||||
override_kwargs
|
||||
and override_kwargs.recent_generated_image_file_ids
|
||||
and self.img_provider.supports_reference_images
|
||||
):
|
||||
# If no explicit reference was provided, default to the most recently generated image.
|
||||
reference_image_file_ids = [
|
||||
override_kwargs.recent_generated_image_file_ids[-1]
|
||||
]
|
||||
else:
|
||||
reference_image_file_ids = []
|
||||
if raw_reference_ids is None:
|
||||
# No references requested — plain generation.
|
||||
return []
|
||||
|
||||
# Deduplicate while preserving order.
|
||||
if not isinstance(raw_reference_ids, list) or not all(
|
||||
isinstance(file_id, str) for file_id in raw_reference_ids
|
||||
):
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, got {type(raw_reference_ids)}"
|
||||
),
|
||||
llm_facing_message=(
|
||||
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
|
||||
),
|
||||
)
|
||||
|
||||
# Deduplicate while preserving order (first occurrence wins, so the
|
||||
# LLM's intended "primary edit source" stays at index 0).
|
||||
deduped_reference_image_ids: list[str] = []
|
||||
seen_ids: set[str] = set()
|
||||
for file_id in reference_image_file_ids:
|
||||
if file_id in seen_ids:
|
||||
for file_id in raw_reference_ids:
|
||||
file_id = file_id.strip()
|
||||
if not file_id or file_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(file_id)
|
||||
deduped_reference_image_ids.append(file_id)
|
||||
@@ -302,14 +294,14 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
f"Reference images requested but provider '{self.provider}' does not support image-editing context."
|
||||
),
|
||||
llm_facing_message=(
|
||||
"This image provider does not support editing from previous image context. "
|
||||
"This image provider does not support editing from existing images. "
|
||||
"Try text-only generation, or switch to a provider/model that supports image edits."
|
||||
),
|
||||
)
|
||||
|
||||
max_reference_images = self.img_provider.max_reference_images
|
||||
if max_reference_images > 0:
|
||||
return deduped_reference_image_ids[-max_reference_images:]
|
||||
return deduped_reference_image_ids[:max_reference_images]
|
||||
return deduped_reference_image_ids
|
||||
|
||||
def _load_reference_images(
|
||||
@@ -358,7 +350,7 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
override_kwargs: ImageGenerationToolOverrideKwargs | None = None,
|
||||
override_kwargs: None = None, # noqa: ARG002
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
if PROMPT_FIELD not in llm_kwargs:
|
||||
@@ -373,7 +365,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
shape = ImageShape(llm_kwargs.get("shape", ImageShape.SQUARE.value))
|
||||
reference_image_file_ids = self._resolve_reference_image_file_ids(
|
||||
llm_kwargs=llm_kwargs,
|
||||
override_kwargs=override_kwargs,
|
||||
)
|
||||
reference_images = self._load_reference_images(reference_image_file_ids)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
@@ -14,7 +13,6 @@ from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import ChatMinimalTextMessage
|
||||
from onyx.tools.models import ImageGenerationToolOverrideKwargs
|
||||
from onyx.tools.models import OpenURLToolOverrideKwargs
|
||||
from onyx.tools.models import ParallelToolCallResponse
|
||||
from onyx.tools.models import PythonToolOverrideKwargs
|
||||
@@ -24,9 +22,6 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolExecutionException
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.models import WebSearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
@@ -110,63 +105,6 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
|
||||
return merged_calls
|
||||
|
||||
|
||||
def _extract_image_file_ids_from_tool_response_message(
|
||||
message: str,
|
||||
) -> list[str]:
|
||||
try:
|
||||
parsed_message = json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
parsed_items: list[Any] = (
|
||||
parsed_message if isinstance(parsed_message, list) else [parsed_message]
|
||||
)
|
||||
file_ids: list[str] = []
|
||||
for item in parsed_items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
file_id = item.get("file_id")
|
||||
if isinstance(file_id, str):
|
||||
file_ids.append(file_id)
|
||||
|
||||
return file_ids
|
||||
|
||||
|
||||
def _extract_recent_generated_image_file_ids(
|
||||
message_history: list[ChatMessageSimple],
|
||||
) -> list[str]:
|
||||
tool_name_by_tool_call_id: dict[str, str] = {}
|
||||
recent_image_file_ids: list[str] = []
|
||||
seen_file_ids: set[str] = set()
|
||||
|
||||
for message in message_history:
|
||||
if message.message_type == MessageType.ASSISTANT and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_name_by_tool_call_id[tool_call.tool_call_id] = tool_call.tool_name
|
||||
continue
|
||||
|
||||
if (
|
||||
message.message_type != MessageType.TOOL_CALL_RESPONSE
|
||||
or not message.tool_call_id
|
||||
):
|
||||
continue
|
||||
|
||||
tool_name = tool_name_by_tool_call_id.get(message.tool_call_id)
|
||||
if tool_name != ImageGenerationTool.NAME:
|
||||
continue
|
||||
|
||||
for file_id in _extract_image_file_ids_from_tool_response_message(
|
||||
message.message
|
||||
):
|
||||
if file_id in seen_file_ids:
|
||||
continue
|
||||
seen_file_ids.add(file_id)
|
||||
recent_image_file_ids.append(file_id)
|
||||
|
||||
return recent_image_file_ids
|
||||
|
||||
|
||||
def _safe_run_single_tool(
|
||||
tool: Tool,
|
||||
tool_call: ToolCallKickoff,
|
||||
@@ -386,9 +324,6 @@ def run_tool_calls(
|
||||
url_to_citation: dict[str, int] = {
|
||||
url: citation_num for citation_num, url in citation_mapping.items()
|
||||
}
|
||||
recent_generated_image_file_ids = _extract_recent_generated_image_file_ids(
|
||||
message_history
|
||||
)
|
||||
|
||||
# Prepare all tool calls with their override_kwargs
|
||||
# Each tool gets a unique starting citation number to avoid conflicts when running in parallel
|
||||
@@ -405,7 +340,6 @@ def run_tool_calls(
|
||||
| WebSearchToolOverrideKwargs
|
||||
| OpenURLToolOverrideKwargs
|
||||
| PythonToolOverrideKwargs
|
||||
| ImageGenerationToolOverrideKwargs
|
||||
| MemoryToolOverrideKwargs
|
||||
| None
|
||||
) = None
|
||||
@@ -454,10 +388,6 @@ def run_tool_calls(
|
||||
override_kwargs = PythonToolOverrideKwargs(
|
||||
chat_files=chat_files or [],
|
||||
)
|
||||
elif isinstance(tool, ImageGenerationTool):
|
||||
override_kwargs = ImageGenerationToolOverrideKwargs(
|
||||
recent_generated_image_file_ids=recent_generated_image_file_ids
|
||||
)
|
||||
elif isinstance(tool, MemoryTool):
|
||||
override_kwargs = MemoryToolOverrideKwargs(
|
||||
user_name=(
|
||||
|
||||
@@ -38,38 +38,41 @@ class TestAddMemory:
|
||||
def test_add_memory_creates_row(self, db_session: Session, test_user: User) -> None:
|
||||
"""Verify that add_memory inserts a new Memory row."""
|
||||
user_id = test_user.id
|
||||
memory = add_memory(
|
||||
memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="User prefers dark mode",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert memory.id is not None
|
||||
assert memory.user_id == user_id
|
||||
assert memory.memory_text == "User prefers dark mode"
|
||||
assert memory_id is not None
|
||||
|
||||
# Verify it persists
|
||||
fetched = db_session.get(Memory, memory.id)
|
||||
fetched = db_session.get(Memory, memory_id)
|
||||
assert fetched is not None
|
||||
assert fetched.user_id == user_id
|
||||
assert fetched.memory_text == "User prefers dark mode"
|
||||
|
||||
def test_add_multiple_memories(self, db_session: Session, test_user: User) -> None:
|
||||
"""Verify that multiple memories can be added for the same user."""
|
||||
user_id = test_user.id
|
||||
m1 = add_memory(
|
||||
m1_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Favorite color is blue",
|
||||
db_session=db_session,
|
||||
)
|
||||
m2 = add_memory(
|
||||
m2_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Works in engineering",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert m1.id != m2.id
|
||||
assert m1.memory_text == "Favorite color is blue"
|
||||
assert m2.memory_text == "Works in engineering"
|
||||
assert m1_id != m2_id
|
||||
fetched_m1 = db_session.get(Memory, m1_id)
|
||||
fetched_m2 = db_session.get(Memory, m2_id)
|
||||
assert fetched_m1 is not None
|
||||
assert fetched_m2 is not None
|
||||
assert fetched_m1.memory_text == "Favorite color is blue"
|
||||
assert fetched_m2.memory_text == "Works in engineering"
|
||||
|
||||
|
||||
class TestUpdateMemoryAtIndex:
|
||||
@@ -82,15 +85,17 @@ class TestUpdateMemoryAtIndex:
|
||||
add_memory(user_id=user_id, memory_text="Memory 1", db_session=db_session)
|
||||
add_memory(user_id=user_id, memory_text="Memory 2", db_session=db_session)
|
||||
|
||||
updated = update_memory_at_index(
|
||||
updated_id = update_memory_at_index(
|
||||
user_id=user_id,
|
||||
index=1,
|
||||
new_text="Updated Memory 1",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert updated is not None
|
||||
assert updated.memory_text == "Updated Memory 1"
|
||||
assert updated_id is not None
|
||||
fetched = db_session.get(Memory, updated_id)
|
||||
assert fetched is not None
|
||||
assert fetched.memory_text == "Updated Memory 1"
|
||||
|
||||
def test_update_memory_at_out_of_range_index(
|
||||
self, db_session: Session, test_user: User
|
||||
@@ -167,7 +172,7 @@ class TestMemoryCap:
|
||||
assert len(rows_before) == MAX_MEMORIES_PER_USER
|
||||
|
||||
# Add one more — should evict the oldest
|
||||
new_memory = add_memory(
|
||||
new_memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="New memory after cap",
|
||||
db_session=db_session,
|
||||
@@ -181,7 +186,7 @@ class TestMemoryCap:
|
||||
# Oldest ("Memory 0") should be gone; "Memory 1" is now the oldest
|
||||
assert rows_after[0].memory_text == "Memory 1"
|
||||
# Newest should be the one we just added
|
||||
assert rows_after[-1].id == new_memory.id
|
||||
assert rows_after[-1].id == new_memory_id
|
||||
assert rows_after[-1].memory_text == "New memory after cap"
|
||||
|
||||
|
||||
@@ -221,22 +226,26 @@ class TestGetMemoriesWithUserId:
|
||||
user_id = test_user_no_memories.id
|
||||
|
||||
# Add a memory
|
||||
memory = add_memory(
|
||||
memory_id = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Memory with use_memories off",
|
||||
db_session=db_session,
|
||||
)
|
||||
assert memory.memory_text == "Memory with use_memories off"
|
||||
fetched = db_session.get(Memory, memory_id)
|
||||
assert fetched is not None
|
||||
assert fetched.memory_text == "Memory with use_memories off"
|
||||
|
||||
# Update that memory
|
||||
updated = update_memory_at_index(
|
||||
updated_id = update_memory_at_index(
|
||||
user_id=user_id,
|
||||
index=0,
|
||||
new_text="Updated memory with use_memories off",
|
||||
db_session=db_session,
|
||||
)
|
||||
assert updated is not None
|
||||
assert updated.memory_text == "Updated memory with use_memories off"
|
||||
assert updated_id is not None
|
||||
fetched_updated = db_session.get(Memory, updated_id)
|
||||
assert fetched_updated is not None
|
||||
assert fetched_updated.memory_text == "Updated memory with use_memories off"
|
||||
|
||||
# Verify get_memories returns the updated memory
|
||||
context = get_memories(test_user_no_memories, db_session)
|
||||
|
||||
@@ -301,7 +301,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -332,7 +331,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -363,7 +361,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -391,7 +388,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -423,7 +419,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -456,7 +451,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -497,7 +491,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -519,7 +512,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -542,7 +534,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -596,7 +587,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
@@ -653,7 +643,6 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
@@ -706,7 +695,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -736,7 +724,6 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.llm.utils import get_max_input_tokens
|
||||
VERTEX_OPUS_MODELS_REJECTING_OUTPUT_CONFIG = [
|
||||
"claude-opus-4-5@20251101",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-7",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -100,6 +100,39 @@ class TestGenerateOllamaDisplayName:
|
||||
result = generate_ollama_display_name("llama3.3:70b")
|
||||
assert "3.3" in result or "3 3" in result # Either format is acceptable
|
||||
|
||||
def test_non_size_tag_shown(self) -> None:
|
||||
"""Test that non-size tags like 'e4b' are included in the display name."""
|
||||
result = generate_ollama_display_name("gemma4:e4b")
|
||||
assert "Gemma" in result
|
||||
assert "4" in result
|
||||
assert "E4B" in result
|
||||
|
||||
def test_size_with_cloud_modifier(self) -> None:
|
||||
"""Test size tag with cloud modifier."""
|
||||
result = generate_ollama_display_name("deepseek-v3.1:671b-cloud")
|
||||
assert "DeepSeek" in result
|
||||
assert "671B" in result
|
||||
assert "Cloud" in result
|
||||
|
||||
def test_size_with_multiple_modifiers(self) -> None:
|
||||
"""Test size tag with multiple modifiers."""
|
||||
result = generate_ollama_display_name("qwen3-vl:235b-instruct-cloud")
|
||||
assert "Qwen" in result
|
||||
assert "235B" in result
|
||||
assert "Instruct" in result
|
||||
assert "Cloud" in result
|
||||
|
||||
def test_quantization_tag_shown(self) -> None:
|
||||
"""Test that quantization tags are included in the display name."""
|
||||
result = generate_ollama_display_name("llama3:q4_0")
|
||||
assert "Llama" in result
|
||||
assert "Q4_0" in result
|
||||
|
||||
def test_cloud_only_tag(self) -> None:
|
||||
"""Test standalone cloud tag."""
|
||||
result = generate_ollama_display_name("glm-4.6:cloud")
|
||||
assert "CLOUD" in result
|
||||
|
||||
|
||||
class TestStripOpenrouterVendorPrefix:
|
||||
"""Tests for OpenRouter vendor prefix stripping."""
|
||||
|
||||
@@ -95,9 +95,9 @@ class TestForceAddSearchToolGuard:
|
||||
without a vector DB."""
|
||||
import inspect
|
||||
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import _construct_tools_impl
|
||||
|
||||
source = inspect.getsource(construct_tools)
|
||||
source = inspect.getsource(_construct_tools_impl)
|
||||
assert (
|
||||
"DISABLE_VECTOR_DB" in source
|
||||
), "construct_tools should reference DISABLE_VECTOR_DB to suppress force-adding SearchTool"
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
"""Tests for ``ImageGenerationTool._resolve_reference_image_file_ids``.
|
||||
|
||||
The resolver turns the LLM's ``reference_image_file_ids`` argument into a
|
||||
cleaned list of file IDs to hand to ``_load_reference_images``. It trusts
|
||||
the LLM's picks — the LLM can only see file IDs that actually appear in
|
||||
the conversation (via ``[attached image — file_id: <id>]`` tags on user
|
||||
messages and the JSON returned by prior generate_image calls), so we
|
||||
don't re-validate against an allow-list in the tool itself.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD,
|
||||
)
|
||||
|
||||
|
||||
def _make_tool(
|
||||
supports_reference_images: bool = True,
|
||||
max_reference_images: int = 16,
|
||||
) -> ImageGenerationTool:
|
||||
"""Construct a tool with a mock provider so no credentials/network are needed."""
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.images.image_generation_tool.get_image_generation_provider"
|
||||
) as mock_get_provider:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.supports_reference_images = supports_reference_images
|
||||
mock_provider.max_reference_images = max_reference_images
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
return ImageGenerationTool(
|
||||
image_generation_credentials=MagicMock(),
|
||||
tool_id=1,
|
||||
emitter=MagicMock(),
|
||||
model="gpt-image-1",
|
||||
provider="openai",
|
||||
)
|
||||
|
||||
|
||||
class TestResolveReferenceImageFileIds:
|
||||
def test_unset_returns_empty_plain_generation(self) -> None:
|
||||
tool = _make_tool()
|
||||
assert tool._resolve_reference_image_file_ids(llm_kwargs={}) == []
|
||||
|
||||
def test_empty_list_is_treated_like_unset(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: []},
|
||||
)
|
||||
assert result == []
|
||||
|
||||
def test_passes_llm_supplied_ids_through(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["upload-1", "gen-1"]},
|
||||
)
|
||||
# Order preserved — first entry is the primary edit source.
|
||||
assert result == ["upload-1", "gen-1"]
|
||||
|
||||
def test_invalid_shape_raises(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: "not-a-list"},
|
||||
)
|
||||
|
||||
def test_non_string_element_raises(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["ok", 123]},
|
||||
)
|
||||
|
||||
def test_deduplicates_preserving_first_occurrence(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["gen-1", "gen-2", "gen-1"]},
|
||||
)
|
||||
assert result == ["gen-1", "gen-2"]
|
||||
|
||||
def test_strips_whitespace_and_skips_empty_strings(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: [" gen-1 ", "", " "]},
|
||||
)
|
||||
assert result == ["gen-1"]
|
||||
|
||||
def test_provider_without_reference_support_raises(self) -> None:
|
||||
tool = _make_tool(supports_reference_images=False)
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["gen-1"]},
|
||||
)
|
||||
|
||||
def test_truncates_to_provider_max_preserving_head(self) -> None:
|
||||
"""When the LLM lists more images than the provider allows, keep the
|
||||
HEAD of the list (the primary edit source + earliest extras) rather
|
||||
than the tail, since the LLM put the most important one first."""
|
||||
tool = _make_tool(max_reference_images=2)
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["a", "b", "c", "d"]},
|
||||
)
|
||||
assert result == ["a", "b"]
|
||||
@@ -1,10 +1,5 @@
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_runner import _extract_image_file_ids_from_tool_response_message
|
||||
from onyx.tools.tool_runner import _extract_recent_generated_image_file_ids
|
||||
from onyx.tools.tool_runner import _merge_tool_calls
|
||||
|
||||
|
||||
@@ -312,62 +307,3 @@ class TestMergeToolCalls:
|
||||
assert len(result) == 1
|
||||
# String should be converted to list item
|
||||
assert result[0].tool_args["queries"] == ["single_query", "q2"]
|
||||
|
||||
|
||||
class TestImageHistoryExtraction:
|
||||
def test_extracts_image_file_ids_from_json_response(self) -> None:
|
||||
msg = '[{"file_id":"img-1","revised_prompt":"v1"},{"file_id":"img-2","revised_prompt":"v2"}]'
|
||||
assert _extract_image_file_ids_from_tool_response_message(msg) == [
|
||||
"img-1",
|
||||
"img-2",
|
||||
]
|
||||
|
||||
def test_extracts_recent_generated_image_ids_from_history(self) -> None:
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="",
|
||||
token_count=1,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
tool_calls=[
|
||||
ToolCallSimple(
|
||||
tool_call_id="call_1",
|
||||
tool_name="generate_image",
|
||||
tool_arguments={"prompt": "test"},
|
||||
token_count=1,
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
|
||||
token_count=1,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id="call_1",
|
||||
),
|
||||
]
|
||||
|
||||
assert _extract_recent_generated_image_file_ids(history) == ["img-1"]
|
||||
|
||||
def test_ignores_non_image_tool_responses(self) -> None:
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="",
|
||||
token_count=1,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
tool_calls=[
|
||||
ToolCallSimple(
|
||||
tool_call_id="call_1",
|
||||
tool_name="web_search",
|
||||
tool_arguments={"queries": ["q"]},
|
||||
token_count=1,
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
|
||||
token_count=1,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id="call_1",
|
||||
),
|
||||
]
|
||||
|
||||
assert _extract_recent_generated_image_file_ids(history) == []
|
||||
|
||||
@@ -15,6 +15,7 @@ type InteractiveStatefulVariant =
|
||||
| "select-heavy"
|
||||
| "select-card"
|
||||
| "select-tinted"
|
||||
| "select-input"
|
||||
| "select-filter"
|
||||
| "sidebar-heavy"
|
||||
| "sidebar-light";
|
||||
@@ -35,6 +36,7 @@ interface InteractiveStatefulProps
|
||||
* - `"select-heavy"` — tinted selected background (for list rows, model pickers)
|
||||
* - `"select-card"` — like select-heavy but filled state has a visible background (for cards/larger surfaces)
|
||||
* - `"select-tinted"` — like select-heavy but with a tinted rest background
|
||||
* - `"select-input"` — rests at neutral-00 (matches input bar), hover/open shows neutral-03 + border-01
|
||||
* - `"select-filter"` — like select-tinted for empty/filled; selected state uses inverted tint backgrounds and inverted text (for filter buttons)
|
||||
* - `"sidebar-heavy"` — sidebar navigation items: muted when unselected (text-03/text-02), bold when selected (text-04/text-03)
|
||||
* - `"sidebar-light"` — sidebar navigation items: uniformly muted across all states (text-02/text-02)
|
||||
|
||||
@@ -350,6 +350,41 @@
|
||||
--interactive-foreground-icon: var(--text-01);
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Select-Input — Empty
|
||||
Matches input bar background at rest, tints on hover/open.
|
||||
--------------------------------------------------------------------------- */
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"] {
|
||||
@apply bg-background-neutral-00;
|
||||
--interactive-foreground: var(--text-04);
|
||||
--interactive-foreground-icon: var(--text-03);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"]:hover:not(
|
||||
[data-disabled]
|
||||
),
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-interaction="hover"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-neutral-03;
|
||||
--interactive-foreground: var(--text-04);
|
||||
--interactive-foreground-icon: var(--text-03);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"]:active:not(
|
||||
[data-disabled]
|
||||
),
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-interaction="active"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-neutral-03;
|
||||
--interactive-foreground: var(--text-05);
|
||||
--interactive-foreground-icon: var(--text-05);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-disabled] {
|
||||
@apply bg-transparent;
|
||||
--interactive-foreground: var(--text-01);
|
||||
--interactive-foreground-icon: var(--text-01);
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Select-Tinted — Filled
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
@@ -73,7 +73,10 @@ export const MemoizedAnchor = memo(
|
||||
: undefined;
|
||||
|
||||
if (!associatedDoc && !associatedSubQuestion) {
|
||||
return <>{children}</>;
|
||||
// Citation not resolved yet (data still streaming) — hide the
|
||||
// raw [[N]](url) link entirely. It will render as a chip once
|
||||
// the citation/document data arrives.
|
||||
return <></>;
|
||||
}
|
||||
|
||||
let icon: React.ReactNode = null;
|
||||
|
||||
@@ -44,6 +44,8 @@ export interface MultiModelPanelProps {
|
||||
errorStackTrace?: string | null;
|
||||
/** Additional error details */
|
||||
errorDetails?: Record<string, any> | null;
|
||||
/** Whether any model is still streaming — disables preferred selection */
|
||||
isGenerating?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -73,19 +75,24 @@ export default function MultiModelPanel({
|
||||
isRetryable,
|
||||
errorStackTrace,
|
||||
errorDetails,
|
||||
isGenerating,
|
||||
}: MultiModelPanelProps) {
|
||||
const ModelIcon = getModelIcon(provider, modelName);
|
||||
|
||||
const canSelect = !isHidden && !isPreferred && !isGenerating;
|
||||
|
||||
const handlePanelClick = useCallback(() => {
|
||||
if (!isHidden && !isPreferred) onSelect();
|
||||
}, [isHidden, isPreferred, onSelect]);
|
||||
if (canSelect) onSelect();
|
||||
}, [canSelect, onSelect]);
|
||||
|
||||
const header = (
|
||||
<div
|
||||
className={cn(
|
||||
"rounded-12",
|
||||
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
|
||||
"rounded-12 transition-colors",
|
||||
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00",
|
||||
canSelect && "cursor-pointer hover:bg-background-tint-02"
|
||||
)}
|
||||
onClick={handlePanelClick}
|
||||
>
|
||||
<ContentAction
|
||||
sizePreset="main-ui"
|
||||
@@ -140,13 +147,7 @@ export default function MultiModelPanel({
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-col gap-3 min-w-0 rounded-16 transition-colors",
|
||||
!isPreferred && "cursor-pointer hover:bg-background-tint-02"
|
||||
)}
|
||||
onClick={handlePanelClick}
|
||||
>
|
||||
<div className="flex flex-col gap-3 min-w-0 rounded-16">
|
||||
{header}
|
||||
{errorMessage ? (
|
||||
<div className="p-4">
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useCallback, useMemo, useEffect, useRef } from "react";
|
||||
import {
|
||||
useState,
|
||||
useCallback,
|
||||
useMemo,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useRef,
|
||||
} from "react";
|
||||
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
|
||||
import { Message } from "@/app/app/interfaces";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
@@ -110,11 +117,27 @@ export default function MultiModelResponseView({
|
||||
// Refs to each panel wrapper for height animation on deselect
|
||||
const panelElsRef = useRef<Map<number, HTMLDivElement>>(new Map());
|
||||
|
||||
// Tracks which non-preferred panels overflow the preferred height cap
|
||||
// Tracks which non-preferred panels overflow the preferred height cap.
|
||||
// Measured via useLayoutEffect after maxHeight is applied to the DOM —
|
||||
// ref callbacks fire before layout and can't reliably detect overflow.
|
||||
const [overflowingPanels, setOverflowingPanels] = useState<Set<number>>(
|
||||
new Set()
|
||||
);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
if (preferredPanelHeight == null || preferredIndex === null) return;
|
||||
const next = new Set<number>();
|
||||
panelElsRef.current.forEach((el, idx) => {
|
||||
if (idx === preferredIndex || hiddenPanels.has(idx)) return;
|
||||
if (el.scrollHeight > el.clientHeight) next.add(idx);
|
||||
});
|
||||
setOverflowingPanels((prev) => {
|
||||
if (prev.size === next.size && Array.from(prev).every((v) => next.has(v)))
|
||||
return prev;
|
||||
return next;
|
||||
});
|
||||
}, [preferredPanelHeight, preferredIndex, hiddenPanels, responses]);
|
||||
|
||||
const preferredPanelRef = useCallback((el: HTMLDivElement | null) => {
|
||||
if (preferredRoRef.current) {
|
||||
preferredRoRef.current.disconnect();
|
||||
@@ -416,6 +439,7 @@ export default function MultiModelResponseView({
|
||||
isRetryable: response.isRetryable,
|
||||
errorStackTrace: response.errorStackTrace,
|
||||
errorDetails: response.errorDetails,
|
||||
isGenerating,
|
||||
}),
|
||||
[
|
||||
preferredIndex,
|
||||
@@ -429,6 +453,7 @@ export default function MultiModelResponseView({
|
||||
onMessageSelection,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
isGenerating,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -515,17 +540,6 @@ export default function MultiModelResponseView({
|
||||
panelElsRef.current.delete(r.modelIndex);
|
||||
}
|
||||
if (isPref) preferredPanelRef(el);
|
||||
if (capped && el) {
|
||||
const doesOverflow = el.scrollHeight > el.clientHeight;
|
||||
setOverflowingPanels((prev) => {
|
||||
const had = prev.has(r.modelIndex);
|
||||
if (doesOverflow === had) return prev;
|
||||
const next = new Set(prev);
|
||||
if (doesOverflow) next.add(r.modelIndex);
|
||||
else next.delete(r.modelIndex);
|
||||
return next;
|
||||
});
|
||||
}
|
||||
}}
|
||||
style={{
|
||||
width: `${selectionEntered ? finalW : startW}px`,
|
||||
@@ -536,21 +550,19 @@ export default function MultiModelResponseView({
|
||||
: "none",
|
||||
maxHeight: capped ? preferredPanelHeight : undefined,
|
||||
overflow: capped ? "hidden" : undefined,
|
||||
position: capped ? "relative" : undefined,
|
||||
...(overflows
|
||||
? {
|
||||
maskImage:
|
||||
"linear-gradient(to bottom, black calc(100% - 6rem), transparent 100%)",
|
||||
WebkitMaskImage:
|
||||
"linear-gradient(to bottom, black calc(100% - 6rem), transparent 100%)",
|
||||
}
|
||||
: {}),
|
||||
}}
|
||||
>
|
||||
<div className={cn(isNonPref && "opacity-50")}>
|
||||
<MultiModelPanel {...buildPanelProps(r, isNonPref)} />
|
||||
</div>
|
||||
{overflows && (
|
||||
<div
|
||||
className="absolute inset-x-0 bottom-0 h-24 pointer-events-none"
|
||||
style={{
|
||||
background:
|
||||
"linear-gradient(to top, var(--background-tint-01) 0%, transparent 100%)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -136,32 +136,49 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
finalAnswerComing
|
||||
);
|
||||
|
||||
// Memoize merged citations separately to avoid creating new object when neither source changed
|
||||
// Merge streaming citation/document data with chatState props.
|
||||
// NOTE: citationMap and documentMap from usePacketProcessor are mutated in
|
||||
// place (same object reference), so we use citations.length / documentMap.size
|
||||
// as change-detection proxies to bust the memo cache when new data arrives.
|
||||
const mergedCitations = useMemo(
|
||||
() => ({
|
||||
...chatState.citations,
|
||||
...citationMap,
|
||||
}),
|
||||
[chatState.citations, citationMap]
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
[chatState.citations, citationMap, citations.length]
|
||||
);
|
||||
|
||||
// Create a chatState that uses streaming citations for immediate rendering
|
||||
// This merges the prop citations with streaming citations, preferring streaming ones
|
||||
// Memoized with granular dependencies to prevent cascading re-renders
|
||||
// Merge streaming documentMap into chatState.docs so inline citation chips
|
||||
// can resolve [1] → document even when chatState.docs is empty (multi-model).
|
||||
const mergedDocs = useMemo(() => {
|
||||
const propDocs = chatState.docs ?? [];
|
||||
if (documentMap.size === 0) return propDocs;
|
||||
const seen = new Set(propDocs.map((d) => d.document_id));
|
||||
const extras = Array.from(documentMap.values()).filter(
|
||||
(d) => !seen.has(d.document_id)
|
||||
);
|
||||
return extras.length > 0 ? [...propDocs, ...extras] : propDocs;
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [chatState.docs, documentMap, documentMap.size]);
|
||||
|
||||
// Create a chatState that uses streaming citations and documents for immediate rendering.
|
||||
// Memoized with granular dependencies to prevent cascading re-renders.
|
||||
// Note: chatState object is recreated upstream on every render, so we depend on
|
||||
// individual fields instead of the whole object for proper memoization
|
||||
// individual fields instead of the whole object for proper memoization.
|
||||
const effectiveChatState = useMemo<FullChatState>(
|
||||
() => ({
|
||||
...chatState,
|
||||
citations: mergedCitations,
|
||||
docs: mergedDocs,
|
||||
}),
|
||||
[
|
||||
chatState.agent,
|
||||
chatState.docs,
|
||||
chatState.setPresentingDocument,
|
||||
chatState.overriddenModel,
|
||||
chatState.researchType,
|
||||
mergedCitations,
|
||||
mergedDocs,
|
||||
]
|
||||
);
|
||||
|
||||
|
||||
@@ -59,7 +59,6 @@ function TTSButton({ text, voice, speed }: TTSButtonProps) {
|
||||
// Surface streaming voice playback errors to the user via toast
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
console.error("Voice playback error:", error);
|
||||
toast.error(error);
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
@@ -81,6 +81,15 @@ export function ScrollableTable({
|
||||
* Processes content for markdown rendering by handling code blocks and LaTeX
|
||||
*/
|
||||
export const processContent = (content: string): string => {
|
||||
// Strip incomplete citation links at the end of streaming content.
|
||||
// During typewriter animation, [[N]](url) is revealed character by character.
|
||||
// ReactMarkdown can't parse an incomplete link and renders it as raw text.
|
||||
// This regex removes any trailing partial citation pattern so only complete
|
||||
// links are passed to the markdown parser.
|
||||
content = content.replace(/\[\[\d+\]\]\([^)]*$/, "");
|
||||
// Also strip a lone [[ or [[N] or [[N]] at the very end (before the URL part arrives)
|
||||
content = content.replace(/\[\[(?:\d+\]?\]?)?$/, "");
|
||||
|
||||
const codeBlockRegex = /```(\w*)\n[\s\S]*?```|```[\s\S]*?$/g;
|
||||
const matches = content.match(codeBlockRegex);
|
||||
|
||||
|
||||
@@ -34,7 +34,8 @@ export const PROVIDERS: ProviderConfig[] = [
|
||||
providerName: LLMProviderName.ANTHROPIC,
|
||||
recommended: true,
|
||||
models: [
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6", recommended: true },
|
||||
{ name: "claude-opus-4-7", label: "Claude Opus 4.7", recommended: true },
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6" },
|
||||
{ name: "claude-sonnet-4-6", label: "Claude Sonnet 4.6" },
|
||||
],
|
||||
apiKeyPlaceholder: "sk-ant-...",
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
export interface BuildLlmSelection {
|
||||
providerName: string; // e.g., "build-mode-anthropic" (LLMProviderDescriptor.name)
|
||||
provider: string; // e.g., "anthropic"
|
||||
modelName: string; // e.g., "claude-opus-4-6"
|
||||
modelName: string; // e.g., "claude-opus-4-7"
|
||||
}
|
||||
|
||||
// Priority order for smart default LLM selection
|
||||
const LLM_SELECTION_PRIORITY = [
|
||||
{ provider: "anthropic", modelName: "claude-opus-4-6" },
|
||||
{ provider: "anthropic", modelName: "claude-opus-4-7" },
|
||||
{ provider: "openai", modelName: "gpt-5.2" },
|
||||
{ provider: "openrouter", modelName: "minimax/minimax-m2.1" },
|
||||
] as const;
|
||||
@@ -63,10 +63,11 @@ export function getDefaultLlmSelection(
|
||||
export const RECOMMENDED_BUILD_MODELS = {
|
||||
preferred: {
|
||||
provider: "anthropic",
|
||||
modelName: "claude-opus-4-6",
|
||||
displayName: "Claude Opus 4.6",
|
||||
modelName: "claude-opus-4-7",
|
||||
displayName: "Claude Opus 4.7",
|
||||
},
|
||||
alternatives: [
|
||||
{ provider: "anthropic", modelName: "claude-opus-4-6" },
|
||||
{ provider: "anthropic", modelName: "claude-sonnet-4-6" },
|
||||
{ provider: "openai", modelName: "gpt-5.2" },
|
||||
{ provider: "openai", modelName: "gpt-5.1-codex" },
|
||||
@@ -148,7 +149,8 @@ export const BUILD_MODE_PROVIDERS: BuildModeProvider[] = [
|
||||
providerName: "anthropic",
|
||||
recommended: true,
|
||||
models: [
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6", recommended: true },
|
||||
{ name: "claude-opus-4-7", label: "Claude Opus 4.7", recommended: true },
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6" },
|
||||
{ name: "claude-sonnet-4-6", label: "Claude Sonnet 4.6" },
|
||||
],
|
||||
apiKeyPlaceholder: "sk-ant-...",
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { SourceIcon } from "./SourceIcon";
|
||||
import { useState } from "react";
|
||||
import { OnyxIcon } from "./icons/icons";
|
||||
import { GithubIcon, OnyxIcon } from "./icons/icons";
|
||||
|
||||
export function WebResultIcon({
|
||||
url,
|
||||
@@ -23,6 +23,8 @@ export function WebResultIcon({
|
||||
<>
|
||||
{hostname.includes("onyx.app") ? (
|
||||
<OnyxIcon size={size} className="dark:text-[#fff] text-[#000]" />
|
||||
) : hostname === "github.com" || hostname.endsWith(".github.com") ? (
|
||||
<GithubIcon size={size} />
|
||||
) : !error ? (
|
||||
<img
|
||||
className="my-0 rounded-full py-0"
|
||||
|
||||
@@ -46,6 +46,7 @@ import freshdeskIcon from "@public/Freshdesk.png";
|
||||
import geminiSVG from "@public/Gemini.svg";
|
||||
import gitbookDarkIcon from "@public/GitBookDark.png";
|
||||
import gitbookLightIcon from "@public/GitBookLight.png";
|
||||
import githubDarkIcon from "@public/GithubDarkMode.png";
|
||||
import githubLightIcon from "@public/Github.png";
|
||||
import gongIcon from "@public/Gong.png";
|
||||
import googleIcon from "@public/Google.png";
|
||||
@@ -855,7 +856,7 @@ export const GitbookIcon = createLogoIcon(gitbookDarkIcon, {
|
||||
darkSrc: gitbookLightIcon,
|
||||
});
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true,
|
||||
darkSrc: githubDarkIcon,
|
||||
});
|
||||
export const GitlabIcon = createLogoIcon(gitlabIcon);
|
||||
export const GmailIcon = createLogoIcon(gmailIcon);
|
||||
|
||||
@@ -106,9 +106,23 @@ export default function useMultiModelChat(
|
||||
[currentLlmModel]
|
||||
);
|
||||
|
||||
const removeModel = useCallback((index: number) => {
|
||||
setSelectedModels((prev) => prev.filter((_, i) => i !== index));
|
||||
}, []);
|
||||
const removeModel = useCallback(
|
||||
(index: number) => {
|
||||
const next = selectedModels.filter((_, i) => i !== index);
|
||||
// When dropping to single-model, switch llmManager to the surviving
|
||||
// model so it becomes the active model instead of reverting to the
|
||||
// user's default.
|
||||
if (next.length === 1 && next[0]) {
|
||||
llmManager.updateCurrentLlm({
|
||||
name: next[0].name,
|
||||
provider: next[0].provider,
|
||||
modelName: next[0].modelName,
|
||||
});
|
||||
}
|
||||
setSelectedModels(next);
|
||||
},
|
||||
[selectedModels, llmManager]
|
||||
);
|
||||
|
||||
const replaceModel = useCallback(
|
||||
(index: number, model: SelectedModel) => {
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useEffect, useMemo, useRef, useState } from "react";
|
||||
|
||||
// Fixed reveal rate — NOT adaptive. Any ceil(delta/N) formula produces
|
||||
// visible chunks on burst packet arrivals. 1 = 60 cps, 2 = 120 cps.
|
||||
const CHARS_PER_FRAME = 2;
|
||||
const CHARS_PER_FRAME = 3;
|
||||
|
||||
/**
|
||||
* Reveals `target` one character at a time on each animation frame.
|
||||
@@ -110,6 +110,23 @@ export function useTypewriter(target: string, enabled: boolean): string {
|
||||
}
|
||||
}, [target.length, displayedLength]);
|
||||
|
||||
// When the user navigates away and back (tab switch, window focus),
|
||||
// snap to all collected content so they see the full response immediately.
|
||||
useEffect(() => {
|
||||
const handleVisibility = () => {
|
||||
if (document.visibilityState === "visible") {
|
||||
const targetLen = targetRef.current.length;
|
||||
if (displayedLengthRef.current < targetLen) {
|
||||
displayedLengthRef.current = targetLen;
|
||||
setDisplayedLength(targetLen);
|
||||
}
|
||||
}
|
||||
};
|
||||
document.addEventListener("visibilitychange", handleVisibility);
|
||||
return () =>
|
||||
document.removeEventListener("visibilitychange", handleVisibility);
|
||||
}, []);
|
||||
|
||||
return useMemo(
|
||||
() => target.slice(0, Math.min(displayedLength, target.length)),
|
||||
[target, displayedLength]
|
||||
|
||||
@@ -173,8 +173,13 @@ function AttachmentItemLayout({
|
||||
rightChildren,
|
||||
}: AttachmentItemLayoutProps) {
|
||||
return (
|
||||
<Section flexDirection="row" gap={0.25} padding={0.25}>
|
||||
<div className={cn("h-[2.25rem] aspect-square rounded-08")}>
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="start"
|
||||
gap={0.25}
|
||||
padding={0.25}
|
||||
>
|
||||
<div className={cn("h-[2.25rem] aspect-square rounded-08 flex-shrink-0")}>
|
||||
<Section>
|
||||
<div
|
||||
className="attachment-button__icon-wrapper"
|
||||
@@ -189,6 +194,7 @@ function AttachmentItemLayout({
|
||||
justifyContent="between"
|
||||
alignItems="center"
|
||||
gap={1.5}
|
||||
className="min-w-0"
|
||||
>
|
||||
<div data-testid="attachment-item-title" className="flex-1 min-w-0">
|
||||
<Content
|
||||
|
||||
@@ -53,18 +53,17 @@ export class HTTPStreamingTTSPlayer {
|
||||
// Create abort controller for this request
|
||||
this.abortController = new AbortController();
|
||||
|
||||
// Build URL with query params
|
||||
const params = new URLSearchParams();
|
||||
params.set("text", text);
|
||||
if (voice) params.set("voice", voice);
|
||||
params.set("speed", speed.toString());
|
||||
|
||||
const url = `${this.getAPIUrl()}?${params}`;
|
||||
const url = this.getAPIUrl();
|
||||
const body = JSON.stringify({
|
||||
text,
|
||||
...(voice && { voice }),
|
||||
speed,
|
||||
});
|
||||
|
||||
// Check if MediaSource is supported
|
||||
if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
|
||||
// Fallback to simple buffered playback
|
||||
return this.fallbackSpeak(url);
|
||||
return this.fallbackSpeak(url, body);
|
||||
}
|
||||
|
||||
// Create MediaSource and audio element
|
||||
@@ -129,15 +128,21 @@ export class HTTPStreamingTTSPlayer {
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body,
|
||||
signal: this.abortController.signal,
|
||||
credentials: "include", // Include cookies for authentication
|
||||
credentials: "include",
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`TTS request failed: ${response.status} - ${errorText}`
|
||||
);
|
||||
let message = `TTS request failed (${response.status})`;
|
||||
try {
|
||||
const errorJson = await response.json();
|
||||
if (errorJson.detail) message = errorJson.detail;
|
||||
} catch {
|
||||
// response wasn't JSON — use status text
|
||||
}
|
||||
throw new Error(message);
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
@@ -242,16 +247,24 @@ export class HTTPStreamingTTSPlayer {
|
||||
* Fallback for browsers that don't support MediaSource Extensions.
|
||||
* Buffers all audio before playing.
|
||||
*/
|
||||
private async fallbackSpeak(url: string): Promise<void> {
|
||||
private async fallbackSpeak(url: string, body: string): Promise<void> {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body,
|
||||
signal: this.abortController?.signal,
|
||||
credentials: "include", // Include cookies for authentication
|
||||
credentials: "include",
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`TTS request failed: ${response.status} - ${errorText}`);
|
||||
let message = `TTS request failed (${response.status})`;
|
||||
try {
|
||||
const errorJson = await response.json();
|
||||
if (errorJson.detail) message = errorJson.detail;
|
||||
} catch {
|
||||
// response wasn't JSON — use status text
|
||||
}
|
||||
throw new Error(message);
|
||||
}
|
||||
|
||||
const audioData = await response.arrayBuffer();
|
||||
|
||||
@@ -4,7 +4,7 @@ import { useState, useMemo, useRef } from "react";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { Button, SelectButton, OpenButton } from "@opal/components";
|
||||
import { Button, SelectButton } from "@opal/components";
|
||||
import { SvgPlusCircle, SvgX } from "@opal/icons";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { LLMOption } from "@/refresh-components/popovers/interfaces";
|
||||
@@ -109,7 +109,10 @@ export default function ModelSelector({
|
||||
onRemove(existingIndex);
|
||||
} else if (!atMax) {
|
||||
onAdd(model);
|
||||
setOpen(false);
|
||||
// Close the popover only when we've reached the max model count
|
||||
if (selectedModels.length + 1 >= MAX_MODELS) {
|
||||
setOpen(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -163,26 +166,13 @@ export default function ModelSelector({
|
||||
model.modelName
|
||||
);
|
||||
|
||||
if (!isMultiModel) {
|
||||
// Stable key — keying on model would unmount the pill
|
||||
// on change and leave Radix's anchorRef detached,
|
||||
// flashing the closing popover at (0,0).
|
||||
return (
|
||||
<OpenButton
|
||||
key="single-model-pill"
|
||||
icon={ProviderIcon}
|
||||
onClick={(e: React.MouseEvent) =>
|
||||
handlePillClick(index, e.currentTarget as HTMLElement)
|
||||
}
|
||||
>
|
||||
{model.displayName}
|
||||
</OpenButton>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
key={modelKey(model.provider, model.modelName)}
|
||||
key={
|
||||
isMultiModel
|
||||
? modelKey(model.provider, model.modelName)
|
||||
: "single-model-pill"
|
||||
}
|
||||
className="flex items-center"
|
||||
>
|
||||
{index > 0 && (
|
||||
@@ -194,23 +184,24 @@ export default function ModelSelector({
|
||||
)}
|
||||
<SelectButton
|
||||
icon={ProviderIcon}
|
||||
rightIcon={SvgX}
|
||||
rightIcon={isMultiModel ? SvgX : undefined}
|
||||
state="empty"
|
||||
variant="select-tinted"
|
||||
interaction="hover"
|
||||
variant="select-input"
|
||||
size="lg"
|
||||
onClick={(e: React.MouseEvent) => {
|
||||
const target = e.target as HTMLElement;
|
||||
const btn = e.currentTarget as HTMLElement;
|
||||
const icons = btn.querySelectorAll(
|
||||
".interactive-foreground-icon"
|
||||
);
|
||||
const lastIcon = icons[icons.length - 1];
|
||||
if (lastIcon && lastIcon.contains(target)) {
|
||||
onRemove(index);
|
||||
} else {
|
||||
handlePillClick(index, btn);
|
||||
if (isMultiModel) {
|
||||
const target = e.target as HTMLElement;
|
||||
const btn = e.currentTarget as HTMLElement;
|
||||
const icons = btn.querySelectorAll(
|
||||
".interactive-foreground-icon"
|
||||
);
|
||||
const lastIcon = icons[icons.length - 1];
|
||||
if (lastIcon && lastIcon.contains(target)) {
|
||||
onRemove(index);
|
||||
return;
|
||||
}
|
||||
}
|
||||
handlePillClick(index, e.currentTarget as HTMLElement);
|
||||
}}
|
||||
>
|
||||
{model.displayName}
|
||||
@@ -224,7 +215,7 @@ export default function ModelSelector({
|
||||
</div>
|
||||
|
||||
{!(atMax && replacingIndex === null) && (
|
||||
<Popover.Content side="top" align="end" width="lg">
|
||||
<Popover.Content side="top" align="end" width="xl">
|
||||
<ModelListContent
|
||||
llmProviders={llmManager.llmProviders}
|
||||
isLoading={llmManager.isLoadingProviders}
|
||||
|
||||
@@ -184,20 +184,18 @@ export function FileCard({
|
||||
}
|
||||
>
|
||||
<div className="min-w-0 max-w-[12rem]">
|
||||
<Interactive.Container border heightVariant="fit">
|
||||
<div className="[&_.opal-content-md-title-row]:min-w-0 [&_.opal-content-md-title]:break-all">
|
||||
<AttachmentItemLayout
|
||||
icon={isProcessing ? SimpleLoader : SvgFileText}
|
||||
title={file.name}
|
||||
description={
|
||||
isProcessing
|
||||
? file.status === UserFileStatus.UPLOADING
|
||||
? "Uploading..."
|
||||
: "Processing..."
|
||||
: typeLabel
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<Interactive.Container border heightVariant="fit" widthVariant="full">
|
||||
<AttachmentItemLayout
|
||||
icon={isProcessing ? SimpleLoader : SvgFileText}
|
||||
title={file.name}
|
||||
description={
|
||||
isProcessing
|
||||
? file.status === UserFileStatus.UPLOADING
|
||||
? "Uploading..."
|
||||
: "Processing..."
|
||||
: typeLabel
|
||||
}
|
||||
/>
|
||||
<Spacer horizontal rem={0.5} />
|
||||
</Interactive.Container>
|
||||
</div>
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
mergeFetchedModelConfigurations,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
@@ -119,7 +120,13 @@ function BedrockModalInternals({
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
formikProps.setFieldValue(
|
||||
"model_configurations",
|
||||
mergeFetchedModelConfigurations(
|
||||
models,
|
||||
formikProps.values.model_configurations
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
mergeFetchedModelConfigurations,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
@@ -55,7 +56,13 @@ function BifrostModalInternals({
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
formikProps.setFieldValue(
|
||||
"model_configurations",
|
||||
mergeFetchedModelConfigurations(
|
||||
models,
|
||||
formikProps.values.model_configurations
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues as BaseLLMModalValues,
|
||||
mergeFetchedModelConfigurations,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
@@ -61,7 +62,13 @@ function LMStudioModalInternals({
|
||||
if (data.error) {
|
||||
throw new Error(data.error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", data.models);
|
||||
formikProps.setFieldValue(
|
||||
"model_configurations",
|
||||
mergeFetchedModelConfigurations(
|
||||
data.models,
|
||||
formikProps.values.model_configurations
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
mergeFetchedModelConfigurations,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
@@ -57,7 +58,13 @@ function LiteLLMProxyModalInternals({
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
formikProps.setFieldValue(
|
||||
"model_configurations",
|
||||
mergeFetchedModelConfigurations(
|
||||
models,
|
||||
formikProps.values.model_configurations
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
mergeFetchedModelConfigurations,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
@@ -83,7 +84,13 @@ function OllamaModalInternals({
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
formikProps.setFieldValue(
|
||||
"model_configurations",
|
||||
mergeFetchedModelConfigurations(
|
||||
models,
|
||||
formikProps.values.model_configurations
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
mergeFetchedModelConfigurations,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
@@ -55,7 +56,13 @@ function OpenAICompatibleModalInternals({
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
formikProps.setFieldValue(
|
||||
"model_configurations",
|
||||
mergeFetchedModelConfigurations(
|
||||
models,
|
||||
formikProps.values.model_configurations
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
mergeFetchedModelConfigurations,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
@@ -49,7 +50,7 @@ function OpenRouterModalInternals({
|
||||
!formikProps.values.api_base || !formikProps.values.api_key;
|
||||
|
||||
const handleFetchModels = async () => {
|
||||
const { models, error } = await fetchOpenRouterModels({
|
||||
const { models: fetched, error } = await fetchOpenRouterModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
@@ -57,7 +58,13 @@ function OpenRouterModalInternals({
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
formikProps.setFieldValue(
|
||||
"model_configurations",
|
||||
mergeFetchedModelConfigurations(
|
||||
fetched,
|
||||
formikProps.values.model_configurations
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
@@ -123,6 +123,30 @@ export interface BaseLLMFormValues {
|
||||
custom_config?: Record<string, string>;
|
||||
}
|
||||
|
||||
// ─── mergeFetchedModelConfigurations ──────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Merges a freshly-fetched model list with the current form state so that
|
||||
* refreshing the model list does not clobber the user's selections.
|
||||
*
|
||||
* - If the form has no models yet (first fetch / onboarding), the fetched
|
||||
* list is returned as-is so each provider's own default `is_visible` applies.
|
||||
* - Otherwise, models that already exist in the form keep their prior
|
||||
* `is_visible` value, and newly-discovered models are added unselected so
|
||||
* the user can opt-in explicitly.
|
||||
*/
|
||||
export function mergeFetchedModelConfigurations(
|
||||
fetched: ModelConfiguration[],
|
||||
existing: ModelConfiguration[]
|
||||
): ModelConfiguration[] {
|
||||
if (existing.length === 0) return fetched;
|
||||
const priorByName = new Map(existing.map((m) => [m.name, m]));
|
||||
return fetched.map((model) => {
|
||||
const prior = priorByName.get(model.name);
|
||||
return { ...model, is_visible: prior ? prior.is_visible : false };
|
||||
});
|
||||
}
|
||||
|
||||
// ─── Misc ─────────────────────────────────────────────────────────────────
|
||||
|
||||
export type TestApiKeyResult =
|
||||
|
||||
Reference in New Issue
Block a user