Compare commits

...

37 Commits

Author SHA1 Message Date
github-actions[bot]
9862b0ef59 fix(logos): github logo displays correctly in dark mode (#10269) to release v3.2 (#10284)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-16 15:10:14 -07:00
github-actions[bot]
8a7aeb2c59 feat(anthropic): include Opus 4.7 in recommended models (#10273) to release v3.2 (#10280)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-16 14:33:10 -07:00
github-actions[bot]
648dcd1e47 feat(img): Editing User Uploaded Images (#10264) to release v3.2 (#10278)
Co-authored-by: Danelegend <43459662+Danelegend@users.noreply.github.com>
2026-04-16 13:46:23 -07:00
Nikolas Garza
f73796928c fix(chat): only header click selects preferred in multi-model panels (#10198) to release v3.2 (#10234) 2026-04-15 14:37:10 -07:00
github-actions[bot]
91101e8f2c fix(chat): keep model selector popover open until max models reached (#10203) to release v3.2 (#10216)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-15 14:34:46 -07:00
github-actions[bot]
44bb3ded44 fix(chat): fix fade gradient missing on last multi-model panel (#10199) to release v3.2 (#10214)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-15 14:34:38 -07:00
github-actions[bot]
493e3f23b8 fix(chat): disable hover/pointer states on multi-model panels during streaming (#10202) to release v3.2 (#10215)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-15 14:32:32 -07:00
github-actions[bot]
031c1118bd fix(chat): snap typewriter to full content on tab re-focus (#10226) to release v3.2 (#10231)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-15 14:27:55 -07:00
github-actions[bot]
b8b7702f28 fix(chat): hide incomplete citation links during streaming (#10224) to release v3.2 (#10232)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-15 14:27:44 -07:00
github-actions[bot]
ebb67aede9 fix(voice): send TTS text in POST body instead of query params (#10213) to release v3.2 (#10221)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-15 10:31:07 -07:00
github-actions[bot]
340cd520eb fix(ollama): always include model tag in display name (#10218) to release v3.2 (#10219)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-04-15 09:22:53 -07:00
github-actions[bot]
b626ad232c fix(fe): handle file attachment overflow (#10211) to release v3.2 (#10212)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-14 19:21:25 -07:00
github-actions[bot]
f1ee9c12c0 fix(chat): render inline citation chips in multi-model panels (#10196) to release v3.2 (#10201)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-14 16:07:23 -07:00
github-actions[bot]
378cbedaa1 fix(chat): eliminate long-lived DB session in multi-model worker threads (#10159) to release v3.2 (#10191)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-14 14:16:36 -07:00
Alex Kim
f87e03b194 Add Datadog admission opt-out label to sandbox pods (#10040) 2026-04-14 14:00:32 -07:00
github-actions[bot]
873636a095 fix(chat): speed up text gen (#10186) to release v3.2 (#10187)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-14 13:43:15 -07:00
Justin Tahara
efb194e067 fix(llm): Fix the Auto Fetch workflow (#10181) 2026-04-14 11:16:30 -07:00
github-actions[bot]
3f7dfa7813 feat(notifications): announce upcoming group-based permissions migration (#10178) to release v3.2 (#10180)
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
2026-04-14 22:26:29 +05:30
Wenxi
5f08af3678 fix(google): handle JSON credential payloads in KV storage (@jack-larch) (#10160)
Co-authored-by: Jack Larch <jack.larch@biograph.com>
2026-04-13 18:35:51 -07:00
Nikolas Garza
1243af4f86 chore(hotfix): cherry-pick 2 commits to release v3.2 (#10140)
Co-authored-by: Raunak Bhagat <r@rabh.io>
2026-04-13 14:12:33 -07:00
Nikolas Garza
91e84b8278 feat(chat): smooth character-level streaming (#10093) to release v3.2 (#10138) 2026-04-13 14:12:20 -07:00
Nikolas Garza
1d6baf10db feat(chat): scrollable tables with overflow fade (#10097) to release v3.2 (#10136) 2026-04-13 14:05:16 -07:00
github-actions[bot]
8d26357197 fix(chat): disable Deep Research in multi-model mode (ENG-4009) (#10126) to release v3.2 (#10139)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-13 14:04:36 -07:00
github-actions[bot]
cd43345415 fix: welcome message alignment in chrome extension/desktop (#10094) to release v3.2 (#10135)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-13 13:04:28 -07:00
github-actions[bot]
f99cf2f1b0 fix(chat): isolate multi-model streaming errors to their panels (#10113) to release v3.2 (#10127)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-13 12:49:20 -07:00
Jamison Lahman
7332adb1e6 fix(copy-button): fall back when Clipboard API unavailable (#10080) 2026-04-10 22:49:56 -07:00
Nikolas Garza
0ab1b76765 Revert "feat(chat): smooth character-level streaming (#10076) to release v3.2" (#10082) 2026-04-10 20:49:39 -07:00
github-actions[bot]
40cd0a78a3 feat(chat): smooth character-level streaming (#10076) to release v3.2 (#10081)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-10 20:41:49 -07:00
github-actions[bot]
28d8c5de46 fix(chat): model selection + multi-model follow-up correctness (#10075) to release v3.2 (#10078) 2026-04-10 17:25:00 -07:00
github-actions[bot]
004092767f fix(mcp): prevent masked OAuth credentials from being stored on re-auth (#10066) to release v3.2 (#10069)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-10 14:47:17 -07:00
Nikolas Garza
eb4689a669 fix(chat): hide ModelSelector in search mode (#10052) to release v3.2 (#10068) 2026-04-10 12:43:05 -07:00
github-actions[bot]
47dd8973c1 fix(scim): add advisory lock to prevent seat limit race condition (#10048) to release v3.2 (#10065)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-10 12:05:14 -07:00
github-actions[bot]
a1403ef78c feat(slack-bot): make agent selector searchable (#10036) to release v3.2 (#10038)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-10 12:04:51 -07:00
github-actions[bot]
f96b9d6804 fix(license): exclude service account users from seat count (#10053) to release v3.2 (#10061)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-10 12:04:30 -07:00
github-actions[bot]
711651276c fix(LLM config): resolve API Key before fetching models (#10056) to release v3.2 (#10057)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-10 00:02:33 -07:00
github-actions[bot]
3731110cf9 feat(federated): full thread replies + direct URL fetch in Slack search (#9940) to release v3.2 (#10050)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-09 18:24:02 -07:00
Evan Lohn
8fb7a8718e fix: jira bulk issue fetch batching (#10044) 2026-04-09 20:50:41 -04:00
88 changed files with 2805 additions and 1044 deletions

View File

@@ -13,6 +13,7 @@ from ee.onyx.server.license.models import LicenseSource
from onyx.auth.schemas import UserRole
from onyx.cache.factory import get_cache_backend
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.db.enums import AccountType
from onyx.db.models import License
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -107,12 +108,13 @@ def get_used_seats(tenant_id: str | None = None) -> int:
Get current seat usage directly from database.
For multi-tenant: counts users in UserTenantMapping for this tenant.
For self-hosted: counts all active users (excludes EXT_PERM_USER role
and the anonymous system user).
For self-hosted: counts all active users.
TODO: Exclude API key dummy users from seat counting. API keys create
users with emails like `__DANSWER_API_KEY_*` that should not count toward
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
Only human accounts count toward seat limits.
SERVICE_ACCOUNT (API key dummy users), EXT_PERM_USER, and the
anonymous system user are excluded. BOT (Slack users) ARE counted
because they represent real humans and get upgraded to STANDARD
when they log in via web.
"""
if MULTI_TENANT:
from ee.onyx.server.tenants.user_mapping import get_tenant_count
@@ -129,6 +131,7 @@ def get_used_seats(tenant_id: str | None = None) -> int:
User.is_active == True, # type: ignore # noqa: E712
User.role != UserRole.EXT_PERM_USER,
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
User.account_type != AccountType.SERVICE_ACCOUNT,
)
)
return result.scalar() or 0

View File

@@ -11,6 +11,8 @@ require a valid SCIM bearer token.
from __future__ import annotations
import hashlib
import struct
from uuid import UUID
from fastapi import APIRouter
@@ -22,6 +24,7 @@ from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -65,12 +68,25 @@ from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Group names reserved for system default groups (seeded by migration).
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
# Namespace prefix for the seat-allocation advisory lock. Hashed together
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
# never block each other) and cannot collide with unrelated advisory locks.
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
return struct.unpack("q", digest[:8])[0]
class ScimJSONResponse(JSONResponse):
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
@@ -209,12 +225,37 @@ def _apply_exclusions(
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
"""Return an error message if seat limit is reached, else None.
Acquires a transaction-scoped advisory lock so that concurrent
SCIM requests are serialized. IdPs like Okta send provisioning
requests in parallel batches — without serialization the check is
vulnerable to a TOCTOU race where N concurrent requests each see
"seats available", all insert, and the tenant ends up over its
seat limit.
The lock is held until the caller's next COMMIT or ROLLBACK, which
means the seat count cannot change between the check here and the
subsequent INSERT/UPDATE. Each call site in this module follows
the pattern: _check_seat_availability → write → dal.commit()
(which releases the lock for the next waiting request).
"""
check_fn = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)
if check_fn is None:
return None
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
# The lock id is derived from the tenant so unrelated tenants never block
# each other, and from a namespace string so it cannot collide with
# unrelated advisory locks elsewhere in the codebase.
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
dal.session.execute(
text("SELECT pg_advisory_xact_lock(:lock_id)"),
{"lock_id": lock_id},
)
result = check_fn(dal.session, seats_needed=1)
if not result.available:
return result.error_message or "Seat limit reached"

View File

@@ -4,8 +4,6 @@ from collections.abc import Callable
from typing import Any
from typing import Literal
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_tool_call_failure_messages
from onyx.chat.citation_processor import CitationMapping
@@ -635,7 +633,6 @@ def run_llm_loop(
user_memory_context: UserMemoryContext | None,
llm: LLM,
token_counter: Callable[[str], int],
db_session: Session,
forced_tool_id: int | None = None,
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
@@ -1020,20 +1017,16 @@ def run_llm_loop(
persisted_memory_id: int | None = None
if user_memory_context and user_memory_context.user_id:
if tool_response.rich_response.index_to_replace is not None:
memory = update_memory_at_index(
persisted_memory_id = update_memory_at_index(
user_id=user_memory_context.user_id,
index=tool_response.rich_response.index_to_replace,
new_text=tool_response.rich_response.memory_text,
db_session=db_session,
)
persisted_memory_id = memory.id if memory else None
else:
memory = add_memory(
persisted_memory_id = add_memory(
user_id=user_memory_context.user_id,
memory_text=tool_response.rich_response.memory_text,
db_session=db_session,
)
persisted_memory_id = memory.id
operation: Literal["add", "update"] = (
"update"
if tool_response.rich_response.index_to_replace is not None

View File

@@ -826,6 +826,12 @@ def translate_history_to_llm_format(
base64_data = img_file.to_base64()
image_url = f"data:{image_type};base64,{base64_data}"
content_parts.append(
TextContentPart(
type="text",
text=f"[attached image — file_id: {img_file.file_id}]",
)
)
image_part = ImageContentPart(
type="image_url",
image_url=ImageUrlDetail(

View File

@@ -67,7 +67,6 @@ from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import reserve_multi_model_message_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
@@ -1006,93 +1005,86 @@ def _run_models(
model_llm = setup.llms[model_idx]
try:
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
# Do NOT write to the outer db_session (or any shared DB state) from here;
# all DB writes in this thread must go through thread_db_session.
with get_session_with_current_tenant() as thread_db_session:
thread_tool_dict = construct_tools(
persona=setup.persona,
db_session=thread_db_session,
emitter=model_emitter,
user=user,
llm=model_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=setup.new_msg_req.internal_search_filters,
project_id_filter=setup.search_params.project_id_filter,
persona_id_filter=setup.search_params.persona_id_filter,
bypass_acl=setup.bypass_acl,
slack_context=setup.slack_context,
enable_slack_search=_should_enable_slack_search(
setup.persona, setup.new_msg_req.internal_search_filters
),
# Each function opens short-lived DB sessions on demand.
# Do NOT pass a long-lived session here — it would hold a
# connection for the entire LLM loop (minutes), and cloud
# infrastructure may drop idle connections.
thread_tool_dict = construct_tools(
persona=setup.persona,
emitter=model_emitter,
user=user,
llm=model_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=setup.new_msg_req.internal_search_filters,
project_id_filter=setup.search_params.project_id_filter,
persona_id_filter=setup.search_params.persona_id_filter,
bypass_acl=setup.bypass_acl,
slack_context=setup.slack_context,
enable_slack_search=_should_enable_slack_search(
setup.persona, setup.new_msg_req.internal_search_filters
),
custom_tool_config=CustomToolConfig(
chat_session_id=setup.chat_session.id,
message_id=setup.user_message.id,
additional_headers=setup.custom_tool_additional_headers,
mcp_headers=setup.mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=setup.available_files.user_file_ids,
chat_file_ids=setup.available_files.chat_file_ids,
),
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=setup.search_params.search_usage,
),
custom_tool_config=CustomToolConfig(
chat_session_id=setup.chat_session.id,
message_id=setup.user_message.id,
additional_headers=setup.custom_tool_additional_headers,
mcp_headers=setup.mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=setup.available_files.user_file_ids,
chat_file_ids=setup.available_files.chat_file_ids,
),
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=setup.search_params.search_usage,
)
model_tools = [
tool for tool_list in thread_tool_dict.values() for tool in tool_list
]
if setup.forced_tool_id and setup.forced_tool_id not in {
tool.id for tool in model_tools
}:
raise ValueError(
f"Forced tool {setup.forced_tool_id} not found in tools"
)
model_tools = [
tool
for tool_list in thread_tool_dict.values()
for tool in tool_list
]
if setup.forced_tool_id and setup.forced_tool_id not in {
tool.id for tool in model_tools
}:
raise ValueError(
f"Forced tool {setup.forced_tool_id} not found in tools"
)
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
if n_models == 1 and setup.new_msg_req.deep_research:
if setup.chat_session.project_id:
raise RuntimeError(
"Deep research is not supported for projects"
)
run_deep_research_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
db_session=thread_db_session,
skip_clarification=setup.skip_clarification,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
all_injected_file_metadata=setup.all_injected_file_metadata,
)
else:
run_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
context_files=setup.extracted_context_files,
persona=setup.persona,
user_memory_context=setup.user_memory_context,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
db_session=thread_db_session,
forced_tool_id=setup.forced_tool_id,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
chat_files=setup.chat_files_for_tools,
include_citations=setup.new_msg_req.include_citations,
all_injected_file_metadata=setup.all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
if n_models == 1 and setup.new_msg_req.deep_research:
if setup.chat_session.project_id:
raise RuntimeError("Deep research is not supported for projects")
run_deep_research_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
skip_clarification=setup.skip_clarification,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
all_injected_file_metadata=setup.all_injected_file_metadata,
)
else:
run_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
context_files=setup.extracted_context_files,
persona=setup.persona,
user_memory_context=setup.user_memory_context,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
forced_tool_id=setup.forced_tool_id,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
chat_files=setup.chat_files_for_tools,
include_citations=setup.new_msg_req.include_citations,
all_injected_file_metadata=setup.all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
model_succeeded[model_idx] = True

View File

@@ -1,4 +1,5 @@
import json
from typing import Any
from typing import cast
from urllib.parse import parse_qs
from urllib.parse import ParseResult
@@ -53,6 +54,21 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _load_google_json(raw: object) -> dict[str, Any]:
"""Accept both the current (dict) and legacy (JSON string) KV payload shapes.
Payloads written before the fix for serializing Google credentials into
``EncryptedJson`` columns are stored as JSON strings; new writes store dicts.
Once every install has re-uploaded their Google credentials the legacy
``str`` branch can be removed.
"""
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
return json.loads(raw)
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
@@ -162,12 +178,13 @@ def build_service_account_creds(
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
credential_json = _load_google_json(
get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)
)
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
credential_json = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=GOOGLE_SCOPES[source],
@@ -188,12 +205,12 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
creds = _load_google_json(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
creds = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleAppCredentials(**json.loads(creds_str))
return GoogleAppCredentials(**creds)
def upsert_google_app_cred(
@@ -201,10 +218,14 @@ def upsert_google_app_cred(
) -> None:
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
KV_GOOGLE_DRIVE_CRED_KEY,
app_credentials.model_dump(mode="json"),
encrypt=True,
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
get_kv_store().store(
KV_GMAIL_CRED_KEY, app_credentials.model_dump(mode="json"), encrypt=True
)
else:
raise ValueError(f"Unsupported source: {source}")
@@ -220,12 +241,14 @@ def delete_google_app_cred(source: DocumentSource) -> None:
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
if source == DocumentSource.GOOGLE_DRIVE:
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
creds = _load_google_json(
get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
)
elif source == DocumentSource.GMAIL:
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
creds = _load_google_json(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
else:
raise ValueError(f"Unsupported source: {source}")
return GoogleServiceAccountKey(**json.loads(creds_str))
return GoogleServiceAccountKey(**creds)
def upsert_service_account_key(
@@ -234,12 +257,14 @@ def upsert_service_account_key(
if source == DocumentSource.GOOGLE_DRIVE:
get_kv_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
service_account_key.json(),
service_account_key.model_dump(mode="json"),
encrypt=True,
)
elif source == DocumentSource.GMAIL:
get_kv_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
KV_GMAIL_SERVICE_ACCOUNT_KEY,
service_account_key.model_dump(mode="json"),
encrypt=True,
)
else:
raise ValueError(f"Unsupported source: {source}")

View File

@@ -60,8 +60,10 @@ logger = setup_logger()
ONE_HOUR = 3600
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
_MAX_RESULTS_FETCH_IDS = 5000
_JIRA_FULL_PAGE_SIZE = 50
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
_JIRA_BULK_FETCH_LIMIT = 100
# Constants for Jira field names
_FIELD_REPORTER = "reporter"
@@ -255,15 +257,13 @@ def _bulk_fetch_request(
return resp.json()["issues"]
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
def _bulk_fetch_batch(
jira_client: JIRA, issue_ids: list[str], fields: str | None
) -> list[dict[str, Any]]:
"""Fetch a single batch (must be <= _JIRA_BULK_FETCH_LIMIT).
On JSONDecodeError, recursively bisects until it succeeds or reaches size 1."""
try:
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
return _bulk_fetch_request(jira_client, issue_ids, fields)
except requests.exceptions.JSONDecodeError:
if len(issue_ids) <= 1:
logger.exception(
@@ -277,12 +277,25 @@ def bulk_fetch_issues(
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
)
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
left = _bulk_fetch_batch(jira_client, issue_ids[:mid], fields)
right = _bulk_fetch_batch(jira_client, issue_ids[mid:], fields)
return left + right
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
raw_issues: list[dict[str, Any]] = []
for batch in chunked(issue_ids, _JIRA_BULK_FETCH_LIMIT):
try:
raw_issues.extend(_bulk_fetch_batch(jira_client, list(batch), fields))
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise
return [
Issue(jira_client._options, jira_client._session, raw=issue)

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
from datetime import datetime
from typing import TypedDict
@@ -6,6 +7,14 @@ from pydantic import BaseModel
from onyx.onyxbot.slack.models import ChannelType
@dataclass(frozen=True)
class DirectThreadFetch:
"""Request to fetch a Slack thread directly by channel and timestamp."""
channel_id: str
thread_ts: str
class ChannelMetadata(TypedDict):
"""Type definition for cached channel metadata."""

View File

@@ -19,6 +19,7 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.federated.models import SlackMessage
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
@@ -49,7 +50,6 @@ from onyx.server.federated.models import FederatedConnectorDetail
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
logger = setup_logger()
@@ -58,7 +58,6 @@ HIGHLIGHT_END_CHAR = "\ue001"
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
@@ -421,6 +420,94 @@ class SlackQueryResult(BaseModel):
filtered_channels: list[str] # Channels filtered out during this query
def _fetch_thread_from_url(
thread_fetch: DirectThreadFetch,
access_token: str,
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
"""Fetch a thread directly from a Slack URL via conversations.replies."""
channel_id = thread_fetch.channel_id
thread_ts = thread_fetch.thread_ts
slack_client = WebClient(token=access_token)
try:
response = slack_client.conversations_replies(
channel=channel_id,
ts=thread_ts,
)
response.validate()
messages: list[dict[str, Any]] = response.get("messages", [])
except SlackApiError as e:
logger.warning(
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
)
return SlackQueryResult(messages=[], filtered_channels=[])
if not messages:
logger.warning(
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
)
return SlackQueryResult(messages=[], filtered_channels=[])
# Build thread text from all messages
thread_text = _build_thread_text(messages, access_token, None, slack_client)
# Get channel name from metadata cache or API
channel_name = "unknown"
if channel_metadata_dict and channel_id in channel_metadata_dict:
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
else:
try:
ch_response = slack_client.conversations_info(channel=channel_id)
ch_response.validate()
channel_info: dict[str, Any] = ch_response.get("channel", {})
channel_name = channel_info.get("name", "unknown")
except SlackApiError:
pass
# Build the SlackMessage
parent_msg = messages[0]
message_ts = parent_msg.get("ts", thread_ts)
username = parent_msg.get("user", "unknown_user")
parent_text = parent_msg.get("text", "")
snippet = (
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
).replace("\n", " ")
doc_time = datetime.fromtimestamp(float(message_ts))
decay_factor = DOC_TIME_DECAY
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
permalink = (
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
)
slack_message = SlackMessage(
document_id=f"{channel_id}_{message_ts}",
channel_id=channel_id,
message_id=message_ts,
thread_id=None, # Prevent double-enrichment in thread context fetch
link=permalink,
metadata={
"channel": channel_name,
"time": doc_time.isoformat(),
},
timestamp=doc_time,
recency_bias=recency_bias,
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
text=thread_text,
highlighted_texts=set(),
slack_score=100000.0, # High priority — user explicitly asked for this thread
)
logger.info(
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
)
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
def query_slack(
query_string: str,
access_token: str,
@@ -432,7 +519,6 @@ def query_slack(
available_channels: list[str] | None = None,
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
# Check if query has channel override (user specified channels in query)
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
@@ -662,7 +748,6 @@ def _fetch_thread_context(
"""
channel_id = message.channel_id
thread_id = message.thread_id
message_id = message.message_id
# If not a thread, return original text as success
if thread_id is None:
@@ -695,62 +780,37 @@ def _fetch_thread_context(
if len(messages) <= 1:
return ThreadContextResult.success(message.text)
# Build thread text from thread starter + context window around matched message
thread_text = _build_thread_text(
messages, message_id, thread_id, access_token, team_id, slack_client
)
# Build thread text from thread starter + all replies
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
return ThreadContextResult.success(thread_text)
def _build_thread_text(
messages: list[dict[str, Any]],
message_id: str,
thread_id: str,
access_token: str,
team_id: str | None,
slack_client: WebClient,
) -> str:
"""Build the thread text from messages."""
"""Build thread text including all replies.
Includes the thread parent message followed by all replies in order.
"""
msg_text = messages[0].get("text", "")
msg_sender = messages[0].get("user", "")
thread_text = f"<@{msg_sender}>: {msg_text}"
# All messages after index 0 are replies
replies = messages[1:]
if not replies:
return thread_text
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
thread_text += "\n\nReplies:"
if thread_id == message_id:
message_id_idx = 0
else:
message_id_idx = next(
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
)
if not message_id_idx:
return thread_text
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
if start_idx > 1:
thread_text += "\n..."
for i in range(start_idx, message_id_idx):
msg_text = messages[i].get("text", "")
msg_sender = messages[i].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
msg_text = messages[message_id_idx].get("text", "")
msg_sender = messages[message_id_idx].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Add following replies
len_replies = 0
for msg in messages[message_id_idx + 1 :]:
for msg in replies:
msg_text = msg.get("text", "")
msg_sender = msg.get("user", "")
reply = f"\n\n<@{msg_sender}>: {msg_text}"
thread_text += reply
len_replies += len(reply)
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
thread_text += "\n..."
break
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Replace user IDs with names using cached lookups
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
@@ -976,7 +1036,16 @@ def slack_retrieval(
# Query slack with entity filtering
llm = get_default_llm()
query_strings = build_slack_queries(query, llm, entities, available_channels)
query_items = build_slack_queries(query, llm, entities, available_channels)
# Partition into direct thread fetches and search query strings
direct_fetches: list[DirectThreadFetch] = []
query_strings: list[str] = []
for item in query_items:
if isinstance(item, DirectThreadFetch):
direct_fetches.append(item)
else:
query_strings.append(item)
# Determine filtering based on entities OR context (bot)
include_dm = False
@@ -993,8 +1062,16 @@ def slack_retrieval(
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
)
# Build search tasks
search_tasks = [
# Build search tasks — direct thread fetches + keyword searches
search_tasks: list[tuple] = [
(
_fetch_thread_from_url,
(fetch, access_token, channel_metadata_dict),
)
for fetch in direct_fetches
]
search_tasks.extend(
(
query_slack,
(
@@ -1010,7 +1087,7 @@ def slack_retrieval(
),
)
for query_string in query_strings
]
)
# If include_dm is True AND we're not already searching all channels,
# add additional searches without channel filters.

View File

@@ -10,6 +10,7 @@ from pydantic import ValidationError
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.models import ChunkIndexRequest
from onyx.federated_connectors.slack.models import SlackEntities
from onyx.llm.interfaces import LLM
@@ -638,12 +639,38 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
return [query_text]
SLACK_URL_PATTERN = re.compile(
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
)
def extract_slack_message_urls(
query_text: str,
) -> list[tuple[str, str]]:
"""Extract Slack message URLs from query text.
Parses URLs like:
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
Returns list of (channel_id, thread_ts) tuples.
The 16-digit timestamp is converted to Slack ts format (with dot).
"""
results = []
for match in SLACK_URL_PATTERN.finditer(query_text):
channel_id = match.group(1)
raw_ts = match.group(2)
# Convert p1775491616524769 -> 1775491616.524769
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
results.append((channel_id, thread_ts))
return results
def build_slack_queries(
query: ChunkIndexRequest,
llm: LLM,
entities: dict[str, Any] | None = None,
available_channels: list[str] | None = None,
) -> list[str]:
) -> list[str | DirectThreadFetch]:
"""Build Slack query strings with date filtering and query expansion."""
default_search_days = 30
if entities:
@@ -668,6 +695,15 @@ def build_slack_queries(
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
# Check for Slack message URLs — if found, add direct fetch requests
url_fetches: list[DirectThreadFetch] = []
slack_urls = extract_slack_message_urls(query.query)
for channel_id, thread_ts in slack_urls:
url_fetches.append(
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
)
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
# ALWAYS extract channel references from the query (not just for recency queries)
channel_references = extract_channel_references_from_query(query.query)
@@ -684,7 +720,9 @@ def build_slack_queries(
# If valid channels detected, use ONLY those channels with NO keywords
# Return query with ONLY time filter + channel filter (no keywords)
return [build_channel_override_query(channel_references, time_filter)]
return url_fetches + [
build_channel_override_query(channel_references, time_filter)
]
except ValueError as e:
# If validation fails, log the error and continue with normal flow
logger.warning(f"Channel reference validation failed: {e}")
@@ -702,7 +740,8 @@ def build_slack_queries(
rephrased_queries = expand_query_with_llm(query.query, llm)
# Build final query strings with time filters
return [
search_queries = [
rephrased_query.strip() + time_filter
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
]
return url_fetches + search_queries

View File

@@ -11,6 +11,7 @@ from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DB_READONLY_PASSWORD
@@ -346,6 +347,25 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _safe_close_session(session: Session) -> None:
"""Close a session, catching connection-closed errors during cleanup.
Long-running operations (e.g. multi-model LLM loops) can hold a session
open for minutes. If the underlying connection is dropped by cloud
infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
timeouts, etc.), the implicit rollback in Session.close() raises
OperationalError or InterfaceError. Since the work is already complete,
we log and move on — SQLAlchemy internally invalidates the connection
for pool recycling.
"""
try:
session.close()
except DBAPIError:
logger.warning(
"DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool."
)
@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
"""
@@ -358,8 +378,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
# no need to use the schema translation map for self-hosted + default schema
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
with Session(bind=engine, expire_on_commit=False) as session:
session = Session(bind=engine, expire_on_commit=False)
try:
yield session
finally:
_safe_close_session(session)
return
# Create connection with schema translation to handle querying the right schema
@@ -367,8 +390,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
with engine.connect().execution_options(
schema_translate_map=schema_translate_map
) as connection:
with Session(bind=connection, expire_on_commit=False) as session:
session = Session(bind=connection, expire_on_commit=False)
try:
yield session
finally:
_safe_close_session(session)
def get_session() -> Generator[Session, None, None]:

View File

@@ -5,6 +5,7 @@ from pydantic import ConfigDict
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
from onyx.db.models import Memory
from onyx.db.models import User
@@ -83,47 +84,51 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
def add_memory(
user_id: UUID,
memory_text: str,
db_session: Session,
) -> Memory:
db_session: Session | None = None,
) -> int:
"""Insert a new Memory row for the given user.
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
one (lowest id) is deleted before inserting the new one.
Returns the id of the newly created Memory row.
"""
existing = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
with get_session_with_current_tenant_if_none(db_session) as db_session:
existing = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
if len(existing) >= MAX_MEMORIES_PER_USER:
db_session.delete(existing[0])
if len(existing) >= MAX_MEMORIES_PER_USER:
db_session.delete(existing[0])
memory = Memory(
user_id=user_id,
memory_text=memory_text,
)
db_session.add(memory)
db_session.commit()
return memory
memory = Memory(
user_id=user_id,
memory_text=memory_text,
)
db_session.add(memory)
db_session.commit()
return memory.id
def update_memory_at_index(
user_id: UUID,
index: int,
new_text: str,
db_session: Session,
) -> Memory | None:
db_session: Session | None = None,
) -> int | None:
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
Returns the updated Memory row, or None if the index is out of range.
Returns the id of the updated Memory row, or None if the index is out of range.
"""
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
with get_session_with_current_tenant_if_none(db_session) as db_session:
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
if index < 0 or index >= len(memory_rows):
return None
if index < 0 or index >= len(memory_rows):
return None
memory = memory_rows[index]
memory.memory_text = new_text
db_session.commit()
return memory
memory = memory_rows[index]
memory.memory_text = new_text
db_session.commit()
return memory.id

View File

@@ -7,8 +7,6 @@ import time
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.citation_processor import DynamicCitationProcessor
@@ -22,6 +20,7 @@ from onyx.chat.models import LlmStepResult
from onyx.chat.models import ToolCallSimple
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
from onyx.configs.constants import MessageType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tools import get_tool_by_name
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
@@ -184,6 +183,14 @@ def generate_final_report(
return has_reasoned
def _get_research_agent_tool_id() -> int:
with get_session_with_current_tenant() as db_session:
return get_tool_by_name(
tool_name=RESEARCH_AGENT_TOOL_NAME,
db_session=db_session,
).id
@log_function_time(print_only=True)
def run_deep_research_llm_loop(
emitter: Emitter,
@@ -193,7 +200,6 @@ def run_deep_research_llm_loop(
custom_agent_prompt: str | None, # noqa: ARG001
llm: LLM,
token_counter: Callable[[str], int],
db_session: Session,
skip_clarification: bool = False,
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
@@ -717,6 +723,7 @@ def run_deep_research_llm_loop(
simple_chat_history.append(assistant_with_tools)
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
research_agent_tool_id = _get_research_agent_tool_id()
for tab_index, report in enumerate(
research_results.intermediate_reports
):
@@ -737,10 +744,7 @@ def run_deep_research_llm_loop(
tab_index=tab_index,
tool_name=current_tool_call.tool_name,
tool_call_id=current_tool_call.tool_call_id,
tool_id=get_tool_by_name(
tool_name=RESEARCH_AGENT_TOOL_NAME,
db_session=db_session,
).id,
tool_id=research_agent_tool_id,
reasoning_tokens=llm_step_result.reasoning
or most_recent_reasoning,
tool_call_arguments=current_tool_call.tool_args,

View File

@@ -1516,6 +1516,10 @@
"display_name": "Claude Opus 4.6",
"model_vendor": "anthropic"
},
"claude-opus-4-7": {
"display_name": "Claude Opus 4.7",
"model_vendor": "anthropic"
},
"claude-opus-4-5-20251101": {
"display_name": "Claude Opus 4.5",
"model_vendor": "anthropic",

View File

@@ -46,6 +46,15 @@ ANTHROPIC_REASONING_EFFORT_BUDGET: dict[ReasoningEffort, int] = {
ReasoningEffort.HIGH: 4096,
}
# Newer Anthropic models (Claude Opus 4.7+) use adaptive thinking with
# output_config.effort instead of thinking.type.enabled + budget_tokens.
ANTHROPIC_ADAPTIVE_REASONING_EFFORT: dict[ReasoningEffort, str] = {
ReasoningEffort.AUTO: "medium",
ReasoningEffort.LOW: "low",
ReasoningEffort.MEDIUM: "medium",
ReasoningEffort.HIGH: "high",
}
# Content part structures for multimodal messages
# The classes in this mirror the OpenAI Chat Completions message types and work well with routers like LiteLLM

View File

@@ -23,6 +23,7 @@ from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.model_response import ModelResponse
from onyx.llm.model_response import ModelResponseStream
from onyx.llm.model_response import Usage
from onyx.llm.models import ANTHROPIC_ADAPTIVE_REASONING_EFFORT
from onyx.llm.models import ANTHROPIC_REASONING_EFFORT_BUDGET
from onyx.llm.models import OPENAI_REASONING_EFFORT
from onyx.llm.request_context import get_llm_mock_response
@@ -67,8 +68,13 @@ STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
_VERTEX_ANTHROPIC_MODELS_REJECTING_OUTPUT_CONFIG = (
"claude-opus-4-5",
"claude-opus-4-6",
"claude-opus-4-7",
)
# Anthropic models that require the adaptive thinking API (thinking.type.adaptive
# + output_config.effort) instead of the legacy thinking.type.enabled + budget_tokens.
_ANTHROPIC_ADAPTIVE_THINKING_MODELS = ("claude-opus-4-7",)
class LLMTimeoutError(Exception):
"""
@@ -230,6 +236,14 @@ def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
)
def _anthropic_uses_adaptive_thinking(model_name: str) -> bool:
normalized_model_name = model_name.lower()
return any(
adaptive_model in normalized_model_name
for adaptive_model in _ANTHROPIC_ADAPTIVE_THINKING_MODELS
)
class LitellmLLM(LLM):
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
See https://python.langchain.com/docs/integrations/chat/litellm"""
@@ -509,10 +523,6 @@ class LitellmLLM(LLM):
}
elif is_claude_model:
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
reasoning_effort
)
# Anthropic requires every assistant message with tool_use
# blocks to start with a thinking block that carries a
# cryptographic signature. We don't preserve those blocks
@@ -520,24 +530,35 @@ class LitellmLLM(LLM):
# contains tool-calling assistant messages. LiteLLM's
# modify_params workaround doesn't cover all providers
# (notably Bedrock).
can_enable_thinking = (
budget_tokens is not None
and not _prompt_contains_tool_call_history(prompt)
)
has_tool_call_history = _prompt_contains_tool_call_history(prompt)
if can_enable_thinking:
assert budget_tokens is not None # mypy
if max_tokens is not None:
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
# and the minimum budget tokens is 1024
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
# call as compared to reducing the budget for reasoning.
max_tokens = max(budget_tokens + 1, max_tokens)
optional_kwargs["thinking"] = {
"type": "enabled",
"budget_tokens": budget_tokens,
}
if _anthropic_uses_adaptive_thinking(self.config.model_name):
# Newer Anthropic models (Claude Opus 4.7+) reject
# thinking.type.enabled — they require the adaptive
# thinking config with output_config.effort.
if not has_tool_call_history:
optional_kwargs["thinking"] = {"type": "adaptive"}
optional_kwargs["output_config"] = {
"effort": ANTHROPIC_ADAPTIVE_REASONING_EFFORT[
reasoning_effort
],
}
else:
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
reasoning_effort
)
if budget_tokens is not None and not has_tool_call_history:
if max_tokens is not None:
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
# and the minimum budget tokens is 1024
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
# call as compared to reducing the budget for reasoning.
max_tokens = max(budget_tokens + 1, max_tokens)
optional_kwargs["thinking"] = {
"type": "enabled",
"budget_tokens": budget_tokens,
}
# LiteLLM just does some mapping like this anyway but is incomplete for Anthropic
optional_kwargs.pop("reasoning_effort", None)

View File

@@ -1,6 +1,6 @@
{
"version": "1.1",
"updated_at": "2026-03-05T00:00:00Z",
"version": "1.2",
"updated_at": "2026-04-16T00:00:00Z",
"providers": {
"openai": {
"default_model": { "name": "gpt-5.4" },
@@ -10,8 +10,12 @@
]
},
"anthropic": {
"default_model": "claude-opus-4-6",
"default_model": "claude-opus-4-7",
"additional_visible_models": [
{
"name": "claude-opus-4-7",
"display_name": "Claude Opus 4.7"
},
{
"name": "claude-opus-4-6",
"display_name": "Claude Opus 4.6"

View File

@@ -65,8 +65,9 @@ IMPORTANT: each call to this tool is independent. Variables from previous calls
GENERATE_IMAGE_GUIDANCE = """
## generate_image
NEVER use generate_image unless the user specifically requests an image.
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
the `file_id` values returned by earlier `generate_image` tool results.
To edit, restyle, or vary an existing image, pass its file_id in `reference_image_file_ids`. \
File IDs come from `[attached image — file_id: <id>]` tags on user-attached images or from prior `generate_image` tool results — never invent one. \
Leave `reference_image_file_ids` unset for a fresh generation.
""".lstrip()
MEMORY_GUIDANCE = """

View File

@@ -618,6 +618,7 @@ done
"app.kubernetes.io/managed-by": "onyx",
"onyx.app/sandbox-id": sandbox_id,
"onyx.app/tenant-id": tenant_id,
"admission.datadoghq.com/enabled": "false",
},
),
spec=pod_spec,

View File

@@ -96,6 +96,32 @@ def _truncate_description(description: str | None, max_length: int = 500) -> str
return description[: max_length - 3] + "..."
# TODO: Replace mask-comparison approach with an explicit Unset sentinel from the
# frontend indicating whether each credential field was actually modified. The current
# approach is brittle (e.g. short credentials produce a fixed-length mask that could
# collide) and mutates request values, which is surprising. The frontend should signal
# "unchanged" vs "new value" directly rather than relying on masked-string equality.
def _restore_masked_oauth_credentials(
request_client_id: str | None,
request_client_secret: str | None,
existing_client: OAuthClientInformationFull,
) -> tuple[str | None, str | None]:
"""If the frontend sent back masked credentials, restore the real stored values."""
if (
request_client_id
and existing_client.client_id
and request_client_id == mask_string(existing_client.client_id)
):
request_client_id = existing_client.client_id
if (
request_client_secret
and existing_client.client_secret
and request_client_secret == mask_string(existing_client.client_secret)
):
request_client_secret = existing_client.client_secret
return request_client_id, request_client_secret
router = APIRouter(prefix="/mcp")
admin_router = APIRouter(prefix="/admin/mcp")
STATE_TTL_SECONDS = 60 * 5 # 5 minutes
@@ -392,6 +418,26 @@ async def _connect_oauth(
detail=f"Server was configured with authentication type {auth_type_str}",
)
# If the frontend sent back masked credentials (unchanged by the user),
# restore the real stored values so we don't overwrite them with masks.
if mcp_server.admin_connection_config:
existing_data = extract_connection_data(
mcp_server.admin_connection_config, apply_mask=False
)
existing_client_raw = existing_data.get(MCPOAuthKeys.CLIENT_INFO.value)
if existing_client_raw:
existing_client = OAuthClientInformationFull.model_validate(
existing_client_raw
)
(
request.oauth_client_id,
request.oauth_client_secret,
) = _restore_masked_oauth_credentials(
request.oauth_client_id,
request.oauth_client_secret,
existing_client,
)
# Create admin config with client info if provided
config_data = MCPConnectionData(headers={})
if request.oauth_client_id and request.oauth_client_secret:
@@ -1356,6 +1402,19 @@ def _upsert_mcp_server(
if client_info_raw:
client_info = OAuthClientInformationFull.model_validate(client_info_raw)
# If the frontend sent back masked credentials (unchanged by the user),
# restore the real stored values so the comparison below sees no change
# and the credentials aren't overwritten with masked strings.
if client_info and request.auth_type == MCPAuthenticationType.OAUTH:
(
request.oauth_client_id,
request.oauth_client_secret,
) = _restore_masked_oauth_credentials(
request.oauth_client_id,
request.oauth_client_secret,
client_info,
)
changing_connection_config = (
not mcp_server.admin_connection_config
or (

View File

@@ -11,6 +11,9 @@ from onyx.db.notification import dismiss_notification
from onyx.db.notification import get_notification_by_id
from onyx.db.notification import get_notifications
from onyx.server.features.build.utils import ensure_build_mode_intro_notification
from onyx.server.features.notifications.utils import (
ensure_permissions_migration_notification,
)
from onyx.server.features.release_notes.utils import (
ensure_release_notes_fresh_and_notify,
)
@@ -49,6 +52,13 @@ def get_notifications_api(
except Exception:
logger.exception("Failed to check for release notes in notifications endpoint")
try:
ensure_permissions_migration_notification(user, db_session)
except Exception:
logger.exception(
"Failed to create permissions_migration_v1 announcement in notifications endpoint"
)
notifications = [
NotificationModel.from_model(notif)
for notif in get_notifications(user, db_session, include_dismissed=True)

View File

@@ -0,0 +1,21 @@
from sqlalchemy.orm import Session
from onyx.configs.constants import NotificationType
from onyx.db.models import User
from onyx.db.notification import create_notification
def ensure_permissions_migration_notification(user: User, db_session: Session) -> None:
# Feature id "permissions_migration_v1" must not change after shipping —
# it is the dedup key on (user_id, notif_type, additional_data).
create_notification(
user_id=user.id,
notif_type=NotificationType.FEATURE_ANNOUNCEMENT,
db_session=db_session,
title="Permissions are changing in Onyx",
description="Roles are moving to group-based permissions. Click for details.",
additional_data={
"feature": "permissions_migration_v1",
"link": "https://docs.onyx.app/admins/permissions/whats_changing",
},
)

View File

@@ -111,6 +111,43 @@ def _mask_string(value: str) -> str:
return value[:4] + "****" + value[-4:]
def _resolve_api_key(
api_key: str | None,
provider_name: str | None,
api_base: str | None,
db_session: Session,
) -> str | None:
"""Return the real API key for model-fetch endpoints.
When editing an existing provider the form value is masked (e.g.
``sk-a****b1c2``). If *provider_name* is supplied we can look up
the unmasked key from the database so the external request succeeds.
The stored key is only returned when the request's *api_base*
matches the value stored in the database.
"""
if not provider_name:
return api_key
existing_provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
if existing_provider and existing_provider.api_key:
# Normalise both URLs before comparing so trailing-slash
# differences don't cause a false mismatch.
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
request_base = (api_base or "").strip().rstrip("/")
if stored_base != request_base:
return api_key
stored_key = existing_provider.api_key.get_value(apply_mask=False)
# Only resolve when the incoming value is the masked form of the
# stored key — i.e. the user hasn't typed a new key.
if api_key and api_key == _mask_string(stored_key):
return stored_key
return api_key
def _sync_fetched_models(
db_session: Session,
provider_name: str,
@@ -1174,16 +1211,17 @@ def get_ollama_available_models(
return sorted_results
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
def _get_openrouter_models_response(api_base: str, api_key: str | None) -> dict:
"""Perform GET to OpenRouter /models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
url = f"{cleaned_api_base}/models"
headers = {
"Authorization": f"Bearer {api_key}",
headers: dict[str, str] = {
# Optional headers recommended by OpenRouter for attribution
"HTTP-Referer": "https://onyx.app",
"X-Title": "Onyx",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
response = httpx.get(url, headers=headers, timeout=10.0)
response.raise_for_status()
@@ -1206,8 +1244,12 @@ def get_openrouter_available_models(
Parses id, name (display), context_length, and architecture.input_modalities.
"""
api_key = _resolve_api_key(
request.api_key, request.provider_name, request.api_base, db_session
)
response_json = _get_openrouter_models_response(
api_base=request.api_base, api_key=request.api_key
api_base=request.api_base, api_key=api_key
)
data = response_json.get("data", [])
@@ -1300,13 +1342,18 @@ def get_lm_studio_available_models(
# If provider_name is given and the api_key hasn't been changed by the user,
# fall back to the stored API key from the database (the form value is masked).
# Only do so when the api_base matches what is stored.
api_key = request.api_key
if request.provider_name and not request.api_key_changed:
existing_provider = fetch_existing_llm_provider(
name=request.provider_name, db_session=db_session
)
if existing_provider and existing_provider.custom_config:
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
if stored_base == cleaned_api_base:
api_key = existing_provider.custom_config.get(
LM_STUDIO_API_KEY_CONFIG_KEY
)
url = f"{cleaned_api_base}/api/v1/models"
headers: dict[str, str] = {}
@@ -1390,8 +1437,12 @@ def get_litellm_available_models(
db_session: Session = Depends(get_session),
) -> list[LitellmFinalModelResponse]:
"""Fetch available models from Litellm proxy /v1/models endpoint."""
api_key = _resolve_api_key(
request.api_key, request.provider_name, request.api_base, db_session
)
response_json = _get_litellm_models_response(
api_key=request.api_key, api_base=request.api_base
api_key=api_key, api_base=request.api_base
)
models = response_json.get("data", [])
@@ -1448,7 +1499,7 @@ def get_litellm_available_models(
return sorted_results
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
def _get_litellm_models_response(api_key: str | None, api_base: str) -> dict:
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
url = f"{cleaned_api_base}/v1/models"
@@ -1523,8 +1574,12 @@ def get_bifrost_available_models(
db_session: Session = Depends(get_session),
) -> list[BifrostFinalModelResponse]:
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
api_key = _resolve_api_key(
request.api_key, request.provider_name, request.api_base, db_session
)
response_json = _get_bifrost_models_response(
api_base=request.api_base, api_key=request.api_key
api_base=request.api_base, api_key=api_key
)
models = response_json.get("data", [])
@@ -1613,8 +1668,12 @@ def get_openai_compatible_server_available_models(
db_session: Session = Depends(get_session),
) -> list[OpenAICompatibleFinalModelResponse]:
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
api_key = _resolve_api_key(
request.api_key, request.provider_name, request.api_base, db_session
)
response_json = _get_openai_compatible_server_response(
api_base=request.api_base, api_key=request.api_key
api_base=request.api_base, api_key=api_key
)
models = response_json.get("data", [])

View File

@@ -183,6 +183,9 @@ def generate_ollama_display_name(model_name: str) -> str:
"qwen2.5:7b""Qwen 2.5 7B"
"mistral:latest""Mistral"
"deepseek-r1:14b""DeepSeek R1 14B"
"gemma4:e4b""Gemma 4 E4B"
"deepseek-v3.1:671b-cloud""DeepSeek V3.1 671B Cloud"
"qwen3-vl:235b-instruct-cloud""Qwen 3-vl 235B Instruct Cloud"
"""
# Split into base name and tag
if ":" in model_name:
@@ -209,13 +212,24 @@ def generate_ollama_display_name(model_name: str) -> str:
# Default: Title case with dashes converted to spaces
display_name = base.replace("-", " ").title()
# Process tag to extract size info (skip "latest")
# Process tag (skip "latest")
if tag and tag.lower() != "latest":
# Extract size like "7b", "70b", "14b"
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])", tag)
# Check for size prefix like "7b", "70b", optionally followed by modifiers
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])(-.+)?$", tag)
if size_match:
size = size_match.group(1).upper()
display_name = f"{display_name} {size}"
remainder = size_match.group(2)
if remainder:
# Format modifiers like "-cloud", "-instruct-cloud"
modifiers = " ".join(
p.title() for p in remainder.strip("-").split("-") if p
)
display_name = f"{display_name} {size} {modifiers}"
else:
display_name = f"{display_name} {size}"
else:
# Non-size tags like "e4b", "q4_0", "fp16", "cloud"
display_name = f"{display_name} {tag.upper()}"
return display_name

View File

@@ -1,13 +1,14 @@
import json
import secrets
from collections.abc import AsyncIterator
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Query
from fastapi import UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
@@ -113,28 +114,47 @@ async def transcribe_audio(
) from exc
def _extract_provider_error(exc: Exception) -> str:
"""Extract a human-readable message from a provider exception.
Provider errors often embed JSON from upstream APIs (e.g. ElevenLabs).
This tries to parse a readable ``message`` field out of common JSON
error shapes; falls back to ``str(exc)`` if nothing better is found.
"""
raw = str(exc)
try:
# Many providers embed JSON after a prefix like "ElevenLabs TTS failed: {...}"
json_start = raw.find("{")
if json_start == -1:
return raw
parsed = json.loads(raw[json_start:])
# Shape: {"detail": {"message": "..."}} (ElevenLabs)
detail = parsed.get("detail", parsed)
if isinstance(detail, dict):
return detail.get("message") or detail.get("error") or raw
if isinstance(detail, str):
return detail
except (json.JSONDecodeError, AttributeError, TypeError):
pass
return raw
class SynthesizeRequest(BaseModel):
text: str = Field(..., min_length=1)
voice: str | None = None
speed: float | None = Field(default=None, ge=0.5, le=2.0)
@router.post("/synthesize")
async def synthesize_speech(
text: str | None = Query(
default=None, description="Text to synthesize", max_length=4096
),
voice: str | None = Query(default=None, description="Voice ID to use"),
speed: float | None = Query(
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
),
body: SynthesizeRequest,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
) -> StreamingResponse:
"""
Synthesize text to speech using the default TTS provider.
Accepts parameters via query string for streaming compatibility.
"""
logger.info(
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
)
if not text:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
"""Synthesize text to speech using the default TTS provider."""
text = body.text
voice = body.voice
speed = body.speed
logger.info(f"TTS request: text length={len(text)}, voice={voice}, speed={speed}")
# Use short-lived session to fetch provider config, then release connection
# before starting the long-running streaming response
@@ -177,31 +197,36 @@ async def synthesize_speech(
logger.error(f"Failed to get voice provider: {exc}")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
# Session is now closed - streaming response won't hold DB connection
# Pull the first chunk before returning the StreamingResponse. If the
# provider rejects the request (e.g. text too long), the error surfaces
# as a proper HTTP error instead of a broken audio stream.
stream_iter = provider.synthesize_stream(
text=text, voice=final_voice, speed=final_speed
)
try:
first_chunk = await stream_iter.__anext__()
except StopAsyncIteration:
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "TTS provider returned no audio")
except Exception as exc:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY, _extract_provider_error(exc)
) from exc
async def audio_stream() -> AsyncIterator[bytes]:
try:
chunk_count = 0
async for chunk in provider.synthesize_stream(
text=text, voice=final_voice, speed=final_speed
):
chunk_count += 1
yield chunk
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
except NotImplementedError as exc:
logger.error(f"TTS not implemented: {exc}")
raise
except Exception as exc:
logger.error(f"Synthesis failed: {exc}")
raise
yield first_chunk
chunk_count = 1
async for chunk in stream_iter:
chunk_count += 1
yield chunk
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
return StreamingResponse(
audio_stream(),
media_type="audio/mpeg",
headers={
"Content-Disposition": "inline; filename=speech.mp3",
# Allow streaming by not setting content-length
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no", # Disable nginx buffering
"X-Accel-Buffering": "no",
},
)

View File

@@ -65,6 +65,7 @@ class Settings(BaseModel):
anonymous_user_enabled: bool | None = None
invite_only_enabled: bool = False
deep_research_enabled: bool | None = None
multi_model_chat_enabled: bool | None = None
search_ui_enabled: bool | None = None
# Whether EE features are unlocked for use.
@@ -89,7 +90,8 @@ class Settings(BaseModel):
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
)
file_token_count_threshold_k: int | None = Field(
default=None, ge=0 # thousands of tokens; None = context-aware default
default=None,
ge=0, # thousands of tokens; None = context-aware default
)
# Connector settings

View File

@@ -208,12 +208,6 @@ class PythonToolOverrideKwargs(BaseModel):
chat_files: list[ChatFile] = []
class ImageGenerationToolOverrideKwargs(BaseModel):
"""Override kwargs for image generation tool calls."""
recent_generated_image_file_ids: list[str] = []
class SearchToolRunContext(BaseModel):
emitter: Emitter

View File

@@ -10,6 +10,7 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import PersonaSearchInfo
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.mcp import get_all_mcp_tools_for_server
@@ -113,10 +114,10 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
def construct_tools(
persona: Persona,
db_session: Session,
emitter: Emitter,
user: User,
llm: LLM,
db_session: Session | None = None,
search_tool_config: SearchToolConfig | None = None,
custom_tool_config: CustomToolConfig | None = None,
file_reader_tool_config: FileReaderToolConfig | None = None,
@@ -131,6 +132,33 @@ def construct_tools(
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
to avoid lazy SQL queries after the session may have been flushed."""
with get_session_with_current_tenant_if_none(db_session) as db_session:
return _construct_tools_impl(
persona=persona,
db_session=db_session,
emitter=emitter,
user=user,
llm=llm,
search_tool_config=search_tool_config,
custom_tool_config=custom_tool_config,
file_reader_tool_config=file_reader_tool_config,
allowed_tool_ids=allowed_tool_ids,
search_usage_forcing_setting=search_usage_forcing_setting,
)
def _construct_tools_impl(
persona: Persona,
db_session: Session,
emitter: Emitter,
user: User,
llm: LLM,
search_tool_config: SearchToolConfig | None = None,
custom_tool_config: CustomToolConfig | None = None,
file_reader_tool_config: FileReaderToolConfig | None = None,
allowed_tool_ids: list[int] | None = None,
search_usage_forcing_setting: SearchToolUsage = SearchToolUsage.AUTO,
) -> dict[int, list[Tool]]:
tool_dict: dict[int, list[Tool]] = {}
# Log which tools are attached to the persona for debugging

View File

@@ -26,7 +26,6 @@ from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeart
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
from onyx.tools.models import ImageGenerationToolOverrideKwargs
from onyx.tools.models import ToolCallException
from onyx.tools.models import ToolExecutionException
from onyx.tools.models import ToolResponse
@@ -48,7 +47,7 @@ PROMPT_FIELD = "prompt"
REFERENCE_IMAGE_FILE_IDS_FIELD = "reference_image_file_ids"
class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
class ImageGenerationTool(Tool[None]):
NAME = "generate_image"
DESCRIPTION = "Generate an image based on a prompt. Do not use unless the user specifically requests an image."
DISPLAY_NAME = "Image Generation"
@@ -142,8 +141,11 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
REFERENCE_IMAGE_FILE_IDS_FIELD: {
"type": "array",
"description": (
"Optional image file IDs to use as reference context for edits/variations. "
"Use the file_id values returned by previous generate_image calls."
"Optional file_ids of existing images to edit or use as reference;"
" the first is the primary edit source."
" Get file_ids from `[attached image — file_id: <id>]` tags on"
" user-attached images or from prior generate_image tool responses."
" Omit for a fresh, unrelated generation."
),
"items": {
"type": "string",
@@ -254,41 +256,31 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
def _resolve_reference_image_file_ids(
self,
llm_kwargs: dict[str, Any],
override_kwargs: ImageGenerationToolOverrideKwargs | None,
) -> list[str]:
raw_reference_ids = llm_kwargs.get(REFERENCE_IMAGE_FILE_IDS_FIELD)
if raw_reference_ids is not None:
if not isinstance(raw_reference_ids, list) or not all(
isinstance(file_id, str) for file_id in raw_reference_ids
):
raise ToolCallException(
message=(
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, got {type(raw_reference_ids)}"
),
llm_facing_message=(
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
),
)
reference_image_file_ids = [
file_id.strip() for file_id in raw_reference_ids if file_id.strip()
]
elif (
override_kwargs
and override_kwargs.recent_generated_image_file_ids
and self.img_provider.supports_reference_images
):
# If no explicit reference was provided, default to the most recently generated image.
reference_image_file_ids = [
override_kwargs.recent_generated_image_file_ids[-1]
]
else:
reference_image_file_ids = []
if raw_reference_ids is None:
# No references requested — plain generation.
return []
# Deduplicate while preserving order.
if not isinstance(raw_reference_ids, list) or not all(
isinstance(file_id, str) for file_id in raw_reference_ids
):
raise ToolCallException(
message=(
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, got {type(raw_reference_ids)}"
),
llm_facing_message=(
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
),
)
# Deduplicate while preserving order (first occurrence wins, so the
# LLM's intended "primary edit source" stays at index 0).
deduped_reference_image_ids: list[str] = []
seen_ids: set[str] = set()
for file_id in reference_image_file_ids:
if file_id in seen_ids:
for file_id in raw_reference_ids:
file_id = file_id.strip()
if not file_id or file_id in seen_ids:
continue
seen_ids.add(file_id)
deduped_reference_image_ids.append(file_id)
@@ -302,14 +294,14 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
f"Reference images requested but provider '{self.provider}' does not support image-editing context."
),
llm_facing_message=(
"This image provider does not support editing from previous image context. "
"This image provider does not support editing from existing images. "
"Try text-only generation, or switch to a provider/model that supports image edits."
),
)
max_reference_images = self.img_provider.max_reference_images
if max_reference_images > 0:
return deduped_reference_image_ids[-max_reference_images:]
return deduped_reference_image_ids[:max_reference_images]
return deduped_reference_image_ids
def _load_reference_images(
@@ -358,7 +350,7 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
def run(
self,
placement: Placement,
override_kwargs: ImageGenerationToolOverrideKwargs | None = None,
override_kwargs: None = None, # noqa: ARG002
**llm_kwargs: Any,
) -> ToolResponse:
if PROMPT_FIELD not in llm_kwargs:
@@ -373,7 +365,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
shape = ImageShape(llm_kwargs.get("shape", ImageShape.SQUARE.value))
reference_image_file_ids = self._resolve_reference_image_file_ids(
llm_kwargs=llm_kwargs,
override_kwargs=override_kwargs,
)
reference_images = self._load_reference_images(reference_image_file_ids)

View File

@@ -1,4 +1,3 @@
import json
import traceback
from collections import defaultdict
from typing import Any
@@ -14,7 +13,6 @@ from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.interface import Tool
from onyx.tools.models import ChatFile
from onyx.tools.models import ChatMinimalTextMessage
from onyx.tools.models import ImageGenerationToolOverrideKwargs
from onyx.tools.models import OpenURLToolOverrideKwargs
from onyx.tools.models import ParallelToolCallResponse
from onyx.tools.models import PythonToolOverrideKwargs
@@ -24,9 +22,6 @@ from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolExecutionException
from onyx.tools.models import ToolResponse
from onyx.tools.models import WebSearchToolOverrideKwargs
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
from onyx.tools.tool_implementations.memory.memory_tool import MemoryToolOverrideKwargs
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
@@ -110,63 +105,6 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
return merged_calls
def _extract_image_file_ids_from_tool_response_message(
message: str,
) -> list[str]:
try:
parsed_message = json.loads(message)
except json.JSONDecodeError:
return []
parsed_items: list[Any] = (
parsed_message if isinstance(parsed_message, list) else [parsed_message]
)
file_ids: list[str] = []
for item in parsed_items:
if not isinstance(item, dict):
continue
file_id = item.get("file_id")
if isinstance(file_id, str):
file_ids.append(file_id)
return file_ids
def _extract_recent_generated_image_file_ids(
message_history: list[ChatMessageSimple],
) -> list[str]:
tool_name_by_tool_call_id: dict[str, str] = {}
recent_image_file_ids: list[str] = []
seen_file_ids: set[str] = set()
for message in message_history:
if message.message_type == MessageType.ASSISTANT and message.tool_calls:
for tool_call in message.tool_calls:
tool_name_by_tool_call_id[tool_call.tool_call_id] = tool_call.tool_name
continue
if (
message.message_type != MessageType.TOOL_CALL_RESPONSE
or not message.tool_call_id
):
continue
tool_name = tool_name_by_tool_call_id.get(message.tool_call_id)
if tool_name != ImageGenerationTool.NAME:
continue
for file_id in _extract_image_file_ids_from_tool_response_message(
message.message
):
if file_id in seen_file_ids:
continue
seen_file_ids.add(file_id)
recent_image_file_ids.append(file_id)
return recent_image_file_ids
def _safe_run_single_tool(
tool: Tool,
tool_call: ToolCallKickoff,
@@ -386,9 +324,6 @@ def run_tool_calls(
url_to_citation: dict[str, int] = {
url: citation_num for citation_num, url in citation_mapping.items()
}
recent_generated_image_file_ids = _extract_recent_generated_image_file_ids(
message_history
)
# Prepare all tool calls with their override_kwargs
# Each tool gets a unique starting citation number to avoid conflicts when running in parallel
@@ -405,7 +340,6 @@ def run_tool_calls(
| WebSearchToolOverrideKwargs
| OpenURLToolOverrideKwargs
| PythonToolOverrideKwargs
| ImageGenerationToolOverrideKwargs
| MemoryToolOverrideKwargs
| None
) = None
@@ -454,10 +388,6 @@ def run_tool_calls(
override_kwargs = PythonToolOverrideKwargs(
chat_files=chat_files or [],
)
elif isinstance(tool, ImageGenerationTool):
override_kwargs = ImageGenerationToolOverrideKwargs(
recent_generated_image_file_ids=recent_generated_image_file_ids
)
elif isinstance(tool, MemoryTool):
override_kwargs = MemoryToolOverrideKwargs(
user_name=(

View File

@@ -38,38 +38,41 @@ class TestAddMemory:
def test_add_memory_creates_row(self, db_session: Session, test_user: User) -> None:
"""Verify that add_memory inserts a new Memory row."""
user_id = test_user.id
memory = add_memory(
memory_id = add_memory(
user_id=user_id,
memory_text="User prefers dark mode",
db_session=db_session,
)
assert memory.id is not None
assert memory.user_id == user_id
assert memory.memory_text == "User prefers dark mode"
assert memory_id is not None
# Verify it persists
fetched = db_session.get(Memory, memory.id)
fetched = db_session.get(Memory, memory_id)
assert fetched is not None
assert fetched.user_id == user_id
assert fetched.memory_text == "User prefers dark mode"
def test_add_multiple_memories(self, db_session: Session, test_user: User) -> None:
"""Verify that multiple memories can be added for the same user."""
user_id = test_user.id
m1 = add_memory(
m1_id = add_memory(
user_id=user_id,
memory_text="Favorite color is blue",
db_session=db_session,
)
m2 = add_memory(
m2_id = add_memory(
user_id=user_id,
memory_text="Works in engineering",
db_session=db_session,
)
assert m1.id != m2.id
assert m1.memory_text == "Favorite color is blue"
assert m2.memory_text == "Works in engineering"
assert m1_id != m2_id
fetched_m1 = db_session.get(Memory, m1_id)
fetched_m2 = db_session.get(Memory, m2_id)
assert fetched_m1 is not None
assert fetched_m2 is not None
assert fetched_m1.memory_text == "Favorite color is blue"
assert fetched_m2.memory_text == "Works in engineering"
class TestUpdateMemoryAtIndex:
@@ -82,15 +85,17 @@ class TestUpdateMemoryAtIndex:
add_memory(user_id=user_id, memory_text="Memory 1", db_session=db_session)
add_memory(user_id=user_id, memory_text="Memory 2", db_session=db_session)
updated = update_memory_at_index(
updated_id = update_memory_at_index(
user_id=user_id,
index=1,
new_text="Updated Memory 1",
db_session=db_session,
)
assert updated is not None
assert updated.memory_text == "Updated Memory 1"
assert updated_id is not None
fetched = db_session.get(Memory, updated_id)
assert fetched is not None
assert fetched.memory_text == "Updated Memory 1"
def test_update_memory_at_out_of_range_index(
self, db_session: Session, test_user: User
@@ -167,7 +172,7 @@ class TestMemoryCap:
assert len(rows_before) == MAX_MEMORIES_PER_USER
# Add one more — should evict the oldest
new_memory = add_memory(
new_memory_id = add_memory(
user_id=user_id,
memory_text="New memory after cap",
db_session=db_session,
@@ -181,7 +186,7 @@ class TestMemoryCap:
# Oldest ("Memory 0") should be gone; "Memory 1" is now the oldest
assert rows_after[0].memory_text == "Memory 1"
# Newest should be the one we just added
assert rows_after[-1].id == new_memory.id
assert rows_after[-1].id == new_memory_id
assert rows_after[-1].memory_text == "New memory after cap"
@@ -221,22 +226,26 @@ class TestGetMemoriesWithUserId:
user_id = test_user_no_memories.id
# Add a memory
memory = add_memory(
memory_id = add_memory(
user_id=user_id,
memory_text="Memory with use_memories off",
db_session=db_session,
)
assert memory.memory_text == "Memory with use_memories off"
fetched = db_session.get(Memory, memory_id)
assert fetched is not None
assert fetched.memory_text == "Memory with use_memories off"
# Update that memory
updated = update_memory_at_index(
updated_id = update_memory_at_index(
user_id=user_id,
index=0,
new_text="Updated memory with use_memories off",
db_session=db_session,
)
assert updated is not None
assert updated.memory_text == "Updated memory with use_memories off"
assert updated_id is not None
fetched_updated = db_session.get(Memory, updated_id)
assert fetched_updated is not None
assert fetched_updated.memory_text == "Updated memory with use_memories off"
# Verify get_memories returns the updated memory
context = get_memories(test_user_no_memories, db_session)

View File

@@ -9,6 +9,7 @@ from unittest.mock import patch
from ee.onyx.db.license import check_seat_availability
from ee.onyx.db.license import delete_license
from ee.onyx.db.license import get_license
from ee.onyx.db.license import get_used_seats
from ee.onyx.db.license import upsert_license
from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicenseSource
@@ -214,3 +215,43 @@ class TestCheckSeatAvailabilityMultiTenant:
assert result.available is False
assert result.error_message is not None
mock_tenant_count.assert_called_once_with("tenant-abc")
class TestGetUsedSeatsAccountTypeFiltering:
"""Verify get_used_seats query excludes SERVICE_ACCOUNT but includes BOT."""
@patch("ee.onyx.db.license.MULTI_TENANT", False)
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
def test_excludes_service_accounts(self, mock_get_session: MagicMock) -> None:
"""SERVICE_ACCOUNT users should not count toward seats."""
mock_session = MagicMock()
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
mock_session.execute.return_value.scalar.return_value = 5
result = get_used_seats()
assert result == 5
# Inspect the compiled query to verify account_type filter
call_args = mock_session.execute.call_args
query = call_args[0][0]
compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
assert "SERVICE_ACCOUNT" in compiled
# BOT should NOT be excluded
assert "BOT" not in compiled
@patch("ee.onyx.db.license.MULTI_TENANT", False)
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
def test_still_excludes_ext_perm_user(self, mock_get_session: MagicMock) -> None:
"""EXT_PERM_USER exclusion should still be present."""
mock_session = MagicMock()
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
mock_session.execute.return_value.scalar.return_value = 3
get_used_seats()
call_args = mock_session.execute.call_args
query = call_args[0][0]
compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
assert "EXT_PERM_USER" in compiled

View File

@@ -301,7 +301,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -332,7 +331,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -363,7 +361,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -391,7 +388,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -423,7 +419,6 @@ class TestRunModels:
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -456,7 +451,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
@@ -497,7 +491,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
@@ -519,7 +512,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
@@ -542,7 +534,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
@@ -596,7 +587,6 @@ class TestRunModels:
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
@@ -653,7 +643,6 @@ class TestRunModels:
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle",
side_effect=lambda *_, **__: completion_called.set(),
@@ -706,7 +695,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
@@ -736,7 +724,6 @@ class TestRunModels:
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",

View File

@@ -0,0 +1,182 @@
from typing import Any
import pytest
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from onyx.connectors.google_utils.google_kv import get_auth_url
from onyx.connectors.google_utils.google_kv import get_google_app_cred
from onyx.connectors.google_utils.google_kv import get_service_account_key
from onyx.connectors.google_utils.google_kv import upsert_google_app_cred
from onyx.connectors.google_utils.google_kv import upsert_service_account_key
from onyx.server.documents.models import GoogleAppCredentials
from onyx.server.documents.models import GoogleAppWebCredentials
from onyx.server.documents.models import GoogleServiceAccountKey
def _make_app_creds() -> GoogleAppCredentials:
return GoogleAppCredentials(
web=GoogleAppWebCredentials(
client_id="client-id.apps.googleusercontent.com",
project_id="test-project",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
client_secret="secret",
redirect_uris=["https://example.com/callback"],
javascript_origins=["https://example.com"],
)
)
def _make_service_account_key() -> GoogleServiceAccountKey:
return GoogleServiceAccountKey(
type="service_account",
project_id="test-project",
private_key_id="private-key-id",
private_key="-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
client_email="test@test-project.iam.gserviceaccount.com",
client_id="123",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test",
universe_domain="googleapis.com",
)
def test_upsert_google_app_cred_stores_dict(monkeypatch: Any) -> None:
stored: dict[str, Any] = {}
class _StubKvStore:
def store(self, key: str, value: object, encrypt: bool) -> None:
stored["key"] = key
stored["value"] = value
stored["encrypt"] = encrypt
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
)
upsert_google_app_cred(_make_app_creds(), DocumentSource.GOOGLE_DRIVE)
assert stored["key"] == KV_GOOGLE_DRIVE_CRED_KEY
assert stored["encrypt"] is True
assert isinstance(stored["value"], dict)
assert stored["value"]["web"]["client_id"] == "client-id.apps.googleusercontent.com"
def test_upsert_service_account_key_stores_dict(monkeypatch: Any) -> None:
stored: dict[str, Any] = {}
class _StubKvStore:
def store(self, key: str, value: object, encrypt: bool) -> None:
stored["key"] = key
stored["value"] = value
stored["encrypt"] = encrypt
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
)
upsert_service_account_key(_make_service_account_key(), DocumentSource.GOOGLE_DRIVE)
assert stored["key"] == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
assert stored["encrypt"] is True
assert isinstance(stored["value"], dict)
assert stored["value"]["project_id"] == "test-project"
@pytest.mark.parametrize("legacy_string", [False, True])
def test_get_google_app_cred_accepts_dict_and_legacy_string(
monkeypatch: Any, legacy_string: bool
) -> None:
payload: dict[str, Any] = _make_app_creds().model_dump(mode="json")
stored_value: object = (
payload if not legacy_string else _make_app_creds().model_dump_json()
)
class _StubKvStore:
def load(self, key: str) -> object:
assert key == KV_GOOGLE_DRIVE_CRED_KEY
return stored_value
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
)
creds = get_google_app_cred(DocumentSource.GOOGLE_DRIVE)
assert creds.web.client_id == "client-id.apps.googleusercontent.com"
@pytest.mark.parametrize("legacy_string", [False, True])
def test_get_service_account_key_accepts_dict_and_legacy_string(
monkeypatch: Any, legacy_string: bool
) -> None:
stored_value: object = (
_make_service_account_key().model_dump(mode="json")
if not legacy_string
else _make_service_account_key().model_dump_json()
)
class _StubKvStore:
def load(self, key: str) -> object:
assert key == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
return stored_value
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
)
key = get_service_account_key(DocumentSource.GOOGLE_DRIVE)
assert key.client_email == "test@test-project.iam.gserviceaccount.com"
@pytest.mark.parametrize("legacy_string", [False, True])
def test_get_auth_url_accepts_dict_and_legacy_string(
monkeypatch: Any, legacy_string: bool
) -> None:
payload = _make_app_creds().model_dump(mode="json")
stored_value: object = (
payload if not legacy_string else _make_app_creds().model_dump_json()
)
stored_state: dict[str, object] = {}
class _StubKvStore:
def load(self, key: str) -> object:
assert key == KV_GOOGLE_DRIVE_CRED_KEY
return stored_value
def store(self, key: str, value: object, encrypt: bool) -> None:
stored_state["key"] = key
stored_state["value"] = value
stored_state["encrypt"] = encrypt
class _StubFlow:
def authorization_url(self, prompt: str) -> tuple[str, None]:
assert prompt == "consent"
return "https://accounts.google.com/o/oauth2/auth?state=test-state", None
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
)
def _from_client_config(
_app_config: object, *, scopes: object, redirect_uri: object
) -> _StubFlow:
del scopes, redirect_uri
return _StubFlow()
monkeypatch.setattr(
"onyx.connectors.google_utils.google_kv.InstalledAppFlow.from_client_config",
_from_client_config,
)
auth_url = get_auth_url(42, DocumentSource.GOOGLE_DRIVE)
assert auth_url.startswith("https://accounts.google.com")
assert stored_state["value"] == {"value": "test-state"}
assert stored_state["encrypt"] is True

View File

@@ -6,6 +6,7 @@ import requests
from jira import JIRA
from jira.resources import Issue
from onyx.connectors.jira.connector import _JIRA_BULK_FETCH_LIMIT
from onyx.connectors.jira.connector import bulk_fetch_issues
@@ -145,3 +146,29 @@ def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
with pytest.raises(requests.exceptions.JSONDecodeError):
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
def test_bulk_fetch_respects_api_batch_limit() -> None:
"""Requests to the bulkfetch endpoint never exceed _JIRA_BULK_FETCH_LIMIT IDs."""
client = _mock_jira_client()
total_issues = _JIRA_BULK_FETCH_LIMIT * 3 + 7
all_ids = [str(i) for i in range(total_issues)]
batch_sizes: list[int] = []
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
ids = json["issueIdsOrKeys"]
batch_sizes.append(len(ids))
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
result = bulk_fetch_issues(client, all_ids)
assert len(result) == total_issues
# keeping this hardcoded because it's the documented limit
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
assert all(size <= 100 for size in batch_sizes)
assert len(batch_sizes) == 4

View File

@@ -0,0 +1,67 @@
"""Tests for _build_thread_text function."""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.context.search.federated.slack_search import _build_thread_text
def _make_msg(user: str, text: str, ts: str) -> dict[str, str]:
return {"user": user, "text": text, "ts": ts}
class TestBuildThreadText:
"""Verify _build_thread_text includes full thread replies up to cap."""
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_includes_all_replies(self, mock_profiles: MagicMock) -> None:
"""All replies within cap are included in output."""
mock_profiles.return_value = {}
messages = [
_make_msg("U1", "parent msg", "1000.0"),
_make_msg("U2", "reply 1", "1001.0"),
_make_msg("U3", "reply 2", "1002.0"),
_make_msg("U4", "reply 3", "1003.0"),
]
result = _build_thread_text(messages, "token", "T123", MagicMock())
assert "parent msg" in result
assert "reply 1" in result
assert "reply 2" in result
assert "reply 3" in result
assert "..." not in result
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_non_thread_returns_parent_only(self, mock_profiles: MagicMock) -> None:
"""Single message (no replies) returns just the parent text."""
mock_profiles.return_value = {}
messages = [_make_msg("U1", "just a message", "1000.0")]
result = _build_thread_text(messages, "token", "T123", MagicMock())
assert "just a message" in result
assert "Replies:" not in result
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_parent_always_first(self, mock_profiles: MagicMock) -> None:
"""Thread parent message is always the first line of output."""
mock_profiles.return_value = {}
messages = [
_make_msg("U1", "I am the parent", "1000.0"),
_make_msg("U2", "I am a reply", "1001.0"),
]
result = _build_thread_text(messages, "token", "T123", MagicMock())
parent_pos = result.index("I am the parent")
reply_pos = result.index("I am a reply")
assert parent_pos < reply_pos
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
def test_user_profiles_resolved(self, mock_profiles: MagicMock) -> None:
"""User IDs in thread text are replaced with display names."""
mock_profiles.return_value = {"U1": "Alice", "U2": "Bob"}
messages = [
_make_msg("U1", "hello", "1000.0"),
_make_msg("U2", "world", "1001.0"),
]
result = _build_thread_text(messages, "token", "T123", MagicMock())
assert "Alice" in result
assert "Bob" in result
assert "<@U1>" not in result
assert "<@U2>" not in result

View File

@@ -0,0 +1,108 @@
"""Tests for Slack URL parsing and direct thread fetch via URL override."""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.federated.slack_search import _fetch_thread_from_url
from onyx.context.search.federated.slack_search_utils import extract_slack_message_urls
class TestExtractSlackMessageUrls:
"""Verify URL parsing extracts channel_id and timestamp correctly."""
def test_standard_url(self) -> None:
query = "summarize https://mycompany.slack.com/archives/C097NBWMY8Y/p1775491616524769"
results = extract_slack_message_urls(query)
assert len(results) == 1
assert results[0] == ("C097NBWMY8Y", "1775491616.524769")
def test_multiple_urls(self) -> None:
query = (
"compare https://co.slack.com/archives/C111/p1234567890123456 "
"and https://co.slack.com/archives/C222/p9876543210987654"
)
results = extract_slack_message_urls(query)
assert len(results) == 2
assert results[0] == ("C111", "1234567890.123456")
assert results[1] == ("C222", "9876543210.987654")
def test_no_urls(self) -> None:
query = "what happened in #general last week?"
results = extract_slack_message_urls(query)
assert len(results) == 0
def test_non_slack_url_ignored(self) -> None:
query = "check https://google.com/archives/C111/p1234567890123456"
results = extract_slack_message_urls(query)
assert len(results) == 0
def test_timestamp_conversion(self) -> None:
"""p prefix removed, dot inserted after 10th digit."""
query = "https://x.slack.com/archives/CABC123/p1775491616524769"
results = extract_slack_message_urls(query)
channel_id, ts = results[0]
assert channel_id == "CABC123"
assert ts == "1775491616.524769"
assert not ts.startswith("p")
assert "." in ts
class TestFetchThreadFromUrl:
"""Verify _fetch_thread_from_url calls conversations.replies and returns SlackMessage."""
@patch("onyx.context.search.federated.slack_search._build_thread_text")
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_successful_fetch(
self, mock_webclient_cls: MagicMock, mock_build_thread: MagicMock
) -> None:
mock_client = MagicMock()
mock_webclient_cls.return_value = mock_client
# Mock conversations_replies
mock_response = MagicMock()
mock_response.get.return_value = [
{"user": "U1", "text": "parent", "ts": "1775491616.524769"},
{"user": "U2", "text": "reply 1", "ts": "1775491617.000000"},
{"user": "U3", "text": "reply 2", "ts": "1775491618.000000"},
]
mock_client.conversations_replies.return_value = mock_response
# Mock channel info
mock_ch_response = MagicMock()
mock_ch_response.get.return_value = {"name": "general"}
mock_client.conversations_info.return_value = mock_ch_response
mock_build_thread.return_value = (
"U1: parent\n\nReplies:\n\nU2: reply 1\n\nU3: reply 2"
)
fetch = DirectThreadFetch(
channel_id="C097NBWMY8Y", thread_ts="1775491616.524769"
)
result = _fetch_thread_from_url(fetch, "xoxp-token")
assert len(result.messages) == 1
msg = result.messages[0]
assert msg.channel_id == "C097NBWMY8Y"
assert msg.thread_id is None # Prevents double-enrichment
assert msg.slack_score == 100000.0
assert "parent" in msg.text
mock_client.conversations_replies.assert_called_once_with(
channel="C097NBWMY8Y", ts="1775491616.524769"
)
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_api_error_returns_empty(self, mock_webclient_cls: MagicMock) -> None:
from slack_sdk.errors import SlackApiError
mock_client = MagicMock()
mock_webclient_cls.return_value = mock_client
mock_client.conversations_replies.side_effect = SlackApiError(
message="channel_not_found",
response=MagicMock(status_code=404),
)
fetch = DirectThreadFetch(channel_id="CBAD", thread_ts="1234567890.123456")
result = _fetch_thread_from_url(fetch, "xoxp-token")
assert len(result.messages) == 0

View File

@@ -29,6 +29,7 @@ from onyx.llm.utils import get_max_input_tokens
VERTEX_OPUS_MODELS_REJECTING_OUTPUT_CONFIG = [
"claude-opus-4-5@20251101",
"claude-opus-4-6",
"claude-opus-4-7",
]

View File

@@ -505,6 +505,7 @@ class TestGetLMStudioAvailableModels:
mock_session = MagicMock()
mock_provider = MagicMock()
mock_provider.api_base = "http://localhost:1234"
mock_provider.custom_config = {"LM_STUDIO_API_KEY": "stored-secret"}
response = {

View File

@@ -100,6 +100,39 @@ class TestGenerateOllamaDisplayName:
result = generate_ollama_display_name("llama3.3:70b")
assert "3.3" in result or "3 3" in result # Either format is acceptable
def test_non_size_tag_shown(self) -> None:
"""Test that non-size tags like 'e4b' are included in the display name."""
result = generate_ollama_display_name("gemma4:e4b")
assert "Gemma" in result
assert "4" in result
assert "E4B" in result
def test_size_with_cloud_modifier(self) -> None:
"""Test size tag with cloud modifier."""
result = generate_ollama_display_name("deepseek-v3.1:671b-cloud")
assert "DeepSeek" in result
assert "671B" in result
assert "Cloud" in result
def test_size_with_multiple_modifiers(self) -> None:
"""Test size tag with multiple modifiers."""
result = generate_ollama_display_name("qwen3-vl:235b-instruct-cloud")
assert "Qwen" in result
assert "235B" in result
assert "Instruct" in result
assert "Cloud" in result
def test_quantization_tag_shown(self) -> None:
"""Test that quantization tags are included in the display name."""
result = generate_ollama_display_name("llama3:q4_0")
assert "Llama" in result
assert "Q4_0" in result
def test_cloud_only_tag(self) -> None:
"""Test standalone cloud tag."""
result = generate_ollama_display_name("glm-4.6:cloud")
assert "CLOUD" in result
class TestStripOpenrouterVendorPrefix:
"""Tests for OpenRouter vendor prefix stripping."""

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
@@ -9,7 +10,9 @@ from uuid import uuid4
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.server.scim.api import _check_seat_availability
from ee.onyx.server.scim.api import _scim_name_to_str
from ee.onyx.server.scim.api import _seat_lock_id_for_tenant
from ee.onyx.server.scim.api import create_user
from ee.onyx.server.scim.api import delete_user
from ee.onyx.server.scim.api import get_user
@@ -741,3 +744,80 @@ class TestEmailCasePreservation:
resource = parse_scim_user(result)
assert resource.userName == "Alice@Example.COM"
assert resource.emails[0].value == "Alice@Example.COM"
class TestSeatLock:
"""Tests for the advisory lock in _check_seat_availability."""
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_abc")
def test_acquires_advisory_lock_before_checking(
self,
_mock_tenant: MagicMock,
mock_dal: MagicMock,
) -> None:
"""The advisory lock must be acquired before the seat check runs."""
call_order: list[str] = []
def track_execute(stmt: Any, _params: Any = None) -> None:
if "pg_advisory_xact_lock" in str(stmt):
call_order.append("lock")
mock_dal.session.execute.side_effect = track_execute
with patch(
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop"
) as mock_fetch:
mock_result = MagicMock()
mock_result.available = True
mock_fn = MagicMock(return_value=mock_result)
mock_fetch.return_value = mock_fn
def track_check(*_args: Any, **_kwargs: Any) -> Any:
call_order.append("check")
return mock_result
mock_fn.side_effect = track_check
_check_seat_availability(mock_dal)
assert call_order == ["lock", "check"]
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_xyz")
def test_lock_uses_tenant_scoped_key(
self,
_mock_tenant: MagicMock,
mock_dal: MagicMock,
) -> None:
"""The lock id must be derived from the tenant via _seat_lock_id_for_tenant."""
mock_result = MagicMock()
mock_result.available = True
mock_check = MagicMock(return_value=mock_result)
with patch(
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
return_value=mock_check,
):
_check_seat_availability(mock_dal)
mock_dal.session.execute.assert_called_once()
params = mock_dal.session.execute.call_args[0][1]
assert params["lock_id"] == _seat_lock_id_for_tenant("tenant_xyz")
def test_seat_lock_id_is_stable_and_tenant_scoped(self) -> None:
"""Lock id must be deterministic and differ across tenants."""
assert _seat_lock_id_for_tenant("t1") == _seat_lock_id_for_tenant("t1")
assert _seat_lock_id_for_tenant("t1") != _seat_lock_id_for_tenant("t2")
def test_no_lock_when_ee_absent(
self,
mock_dal: MagicMock,
) -> None:
"""No advisory lock should be acquired when the EE check is absent."""
with patch(
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
return_value=None,
):
result = _check_seat_availability(mock_dal)
assert result is None
mock_dal.session.execute.assert_not_called()

View File

@@ -95,9 +95,9 @@ class TestForceAddSearchToolGuard:
without a vector DB."""
import inspect
from onyx.tools.tool_constructor import construct_tools
from onyx.tools.tool_constructor import _construct_tools_impl
source = inspect.getsource(construct_tools)
source = inspect.getsource(_construct_tools_impl)
assert (
"DISABLE_VECTOR_DB" in source
), "construct_tools should reference DISABLE_VECTOR_DB to suppress force-adding SearchTool"

View File

@@ -0,0 +1,110 @@
"""Tests for ``ImageGenerationTool._resolve_reference_image_file_ids``.
The resolver turns the LLM's ``reference_image_file_ids`` argument into a
cleaned list of file IDs to hand to ``_load_reference_images``. It trusts
the LLM's picks — the LLM can only see file IDs that actually appear in
the conversation (via ``[attached image — file_id: <id>]`` tags on user
messages and the JSON returned by prior generate_image calls), so we
don't re-validate against an allow-list in the tool itself.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.tools.models import ToolCallException
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
REFERENCE_IMAGE_FILE_IDS_FIELD,
)
def _make_tool(
supports_reference_images: bool = True,
max_reference_images: int = 16,
) -> ImageGenerationTool:
"""Construct a tool with a mock provider so no credentials/network are needed."""
with patch(
"onyx.tools.tool_implementations.images.image_generation_tool.get_image_generation_provider"
) as mock_get_provider:
mock_provider = MagicMock()
mock_provider.supports_reference_images = supports_reference_images
mock_provider.max_reference_images = max_reference_images
mock_get_provider.return_value = mock_provider
return ImageGenerationTool(
image_generation_credentials=MagicMock(),
tool_id=1,
emitter=MagicMock(),
model="gpt-image-1",
provider="openai",
)
class TestResolveReferenceImageFileIds:
def test_unset_returns_empty_plain_generation(self) -> None:
tool = _make_tool()
assert tool._resolve_reference_image_file_ids(llm_kwargs={}) == []
def test_empty_list_is_treated_like_unset(self) -> None:
tool = _make_tool()
result = tool._resolve_reference_image_file_ids(
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: []},
)
assert result == []
def test_passes_llm_supplied_ids_through(self) -> None:
tool = _make_tool()
result = tool._resolve_reference_image_file_ids(
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["upload-1", "gen-1"]},
)
# Order preserved — first entry is the primary edit source.
assert result == ["upload-1", "gen-1"]
def test_invalid_shape_raises(self) -> None:
tool = _make_tool()
with pytest.raises(ToolCallException):
tool._resolve_reference_image_file_ids(
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: "not-a-list"},
)
def test_non_string_element_raises(self) -> None:
tool = _make_tool()
with pytest.raises(ToolCallException):
tool._resolve_reference_image_file_ids(
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["ok", 123]},
)
def test_deduplicates_preserving_first_occurrence(self) -> None:
tool = _make_tool()
result = tool._resolve_reference_image_file_ids(
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["gen-1", "gen-2", "gen-1"]},
)
assert result == ["gen-1", "gen-2"]
def test_strips_whitespace_and_skips_empty_strings(self) -> None:
tool = _make_tool()
result = tool._resolve_reference_image_file_ids(
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: [" gen-1 ", "", " "]},
)
assert result == ["gen-1"]
def test_provider_without_reference_support_raises(self) -> None:
tool = _make_tool(supports_reference_images=False)
with pytest.raises(ToolCallException):
tool._resolve_reference_image_file_ids(
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["gen-1"]},
)
def test_truncates_to_provider_max_preserving_head(self) -> None:
"""When the LLM lists more images than the provider allows, keep the
HEAD of the list (the primary edit source + earliest extras) rather
than the tail, since the LLM put the most important one first."""
tool = _make_tool(max_reference_images=2)
result = tool._resolve_reference_image_file_ids(
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["a", "b", "c", "d"]},
)
assert result == ["a", "b"]

View File

@@ -1,10 +1,5 @@
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import MessageType
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_runner import _extract_image_file_ids_from_tool_response_message
from onyx.tools.tool_runner import _extract_recent_generated_image_file_ids
from onyx.tools.tool_runner import _merge_tool_calls
@@ -312,62 +307,3 @@ class TestMergeToolCalls:
assert len(result) == 1
# String should be converted to list item
assert result[0].tool_args["queries"] == ["single_query", "q2"]
class TestImageHistoryExtraction:
def test_extracts_image_file_ids_from_json_response(self) -> None:
msg = '[{"file_id":"img-1","revised_prompt":"v1"},{"file_id":"img-2","revised_prompt":"v2"}]'
assert _extract_image_file_ids_from_tool_response_message(msg) == [
"img-1",
"img-2",
]
def test_extracts_recent_generated_image_ids_from_history(self) -> None:
history = [
ChatMessageSimple(
message="",
token_count=1,
message_type=MessageType.ASSISTANT,
tool_calls=[
ToolCallSimple(
tool_call_id="call_1",
tool_name="generate_image",
tool_arguments={"prompt": "test"},
token_count=1,
)
],
),
ChatMessageSimple(
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
token_count=1,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id="call_1",
),
]
assert _extract_recent_generated_image_file_ids(history) == ["img-1"]
def test_ignores_non_image_tool_responses(self) -> None:
history = [
ChatMessageSimple(
message="",
token_count=1,
message_type=MessageType.ASSISTANT,
tool_calls=[
ToolCallSimple(
tool_call_id="call_1",
tool_name="web_search",
tool_arguments={"queries": ["q"]},
token_count=1,
)
],
),
ChatMessageSimple(
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
token_count=1,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id="call_1",
),
]
assert _extract_recent_generated_image_file_ids(history) == []

View File

@@ -15,6 +15,7 @@ type InteractiveStatefulVariant =
| "select-heavy"
| "select-card"
| "select-tinted"
| "select-input"
| "select-filter"
| "sidebar-heavy"
| "sidebar-light";
@@ -35,6 +36,7 @@ interface InteractiveStatefulProps
* - `"select-heavy"` — tinted selected background (for list rows, model pickers)
* - `"select-card"` — like select-heavy but filled state has a visible background (for cards/larger surfaces)
* - `"select-tinted"` — like select-heavy but with a tinted rest background
* - `"select-input"` — rests at neutral-00 (matches input bar), hover/open shows neutral-03 + border-01
* - `"select-filter"` — like select-tinted for empty/filled; selected state uses inverted tint backgrounds and inverted text (for filter buttons)
* - `"sidebar-heavy"` — sidebar navigation items: muted when unselected (text-03/text-02), bold when selected (text-04/text-03)
* - `"sidebar-light"` — sidebar navigation items: uniformly muted across all states (text-02/text-02)

View File

@@ -350,6 +350,41 @@
--interactive-foreground-icon: var(--text-01);
}
/* ---------------------------------------------------------------------------
Select-Input — Empty
Matches input bar background at rest, tints on hover/open.
--------------------------------------------------------------------------- */
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"] {
@apply bg-background-neutral-00;
--interactive-foreground: var(--text-04);
--interactive-foreground-icon: var(--text-03);
}
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"]:hover:not(
[data-disabled]
),
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-interaction="hover"]:not(
[data-disabled]
) {
@apply bg-background-neutral-03;
--interactive-foreground: var(--text-04);
--interactive-foreground-icon: var(--text-03);
}
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"]:active:not(
[data-disabled]
),
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-interaction="active"]:not(
[data-disabled]
) {
@apply bg-background-neutral-03;
--interactive-foreground: var(--text-05);
--interactive-foreground-icon: var(--text-05);
}
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-disabled] {
@apply bg-transparent;
--interactive-foreground: var(--text-01);
--interactive-foreground-icon: var(--text-01);
}
/* ---------------------------------------------------------------------------
Select-Tinted — Filled
--------------------------------------------------------------------------- */

16
web/package-lock.json generated
View File

@@ -47,6 +47,7 @@
"clsx": "^2.1.1",
"cmdk": "^1.0.0",
"cookies-next": "^5.1.0",
"copy-to-clipboard": "^3.3.3",
"date-fns": "^3.6.0",
"docx-preview": "^0.3.7",
"favicon-fetch": "^1.0.0",
@@ -8843,6 +8844,15 @@
"react": ">= 16.8.0"
}
},
"node_modules/copy-to-clipboard": {
"version": "3.3.3",
"resolved": "https://registry.npmjs.org/copy-to-clipboard/-/copy-to-clipboard-3.3.3.tgz",
"integrity": "sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==",
"license": "MIT",
"dependencies": {
"toggle-selection": "^1.0.6"
}
},
"node_modules/core-js": {
"version": "3.46.0",
"hasInstallScript": true,
@@ -17426,6 +17436,12 @@
"node": ">=8.0"
}
},
"node_modules/toggle-selection": {
"version": "1.0.6",
"resolved": "https://registry.npmjs.org/toggle-selection/-/toggle-selection-1.0.6.tgz",
"integrity": "sha512-BiZS+C1OS8g/q2RRbJmy59xpyghNBqrr6k5L/uKBGRsTfxmu3ffiRnd8mlGPUVayg8pvfi5urfnu8TU7DVOkLQ==",
"license": "MIT"
},
"node_modules/toposort": {
"version": "2.0.2",
"license": "MIT"

View File

@@ -65,6 +65,7 @@
"clsx": "^2.1.1",
"cmdk": "^1.0.0",
"cookies-next": "^5.1.0",
"copy-to-clipboard": "^3.3.3",
"date-fns": "^3.6.0",
"docx-preview": "^0.3.7",
"favicon-fetch": "^1.0.0",

View File

@@ -17,6 +17,7 @@ import DocumentSetCard from "@/sections/cards/DocumentSetCard";
import CollapsibleSection from "@/app/admin/agents/CollapsibleSection";
import { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
import { StandardAnswerCategoryDropdownField } from "@/components/standardAnswers/StandardAnswerCategoryDropdown";
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
import { RadioGroup } from "@/components/ui/radio-group";
import { RadioGroupItemField } from "@/components/ui/RadioGroupItemField";
import { AlertCircle } from "lucide-react";
@@ -126,6 +127,24 @@ export function SlackChannelConfigFormFields({
return documentSets.filter((ds) => !documentSetContainsSync(ds));
}, [documentSets]);
const searchAgentOptions = useMemo(
() =>
availableAgents.map((persona) => ({
label: persona.name,
value: String(persona.id),
})),
[availableAgents]
);
const nonSearchAgentOptions = useMemo(
() =>
nonSearchAgents.map((persona) => ({
label: persona.name,
value: String(persona.id),
})),
[nonSearchAgents]
);
useEffect(() => {
const invalidSelected = values.document_sets.filter((dsId: number) =>
unselectableSets.some((us) => us.id === dsId)
@@ -355,12 +374,14 @@ export function SlackChannelConfigFormFields({
</>
</SubLabel>
<SelectorFormField
name="persona_id"
options={availableAgents.map((persona) => ({
name: persona.name,
value: persona.id,
}))}
<InputComboBox
placeholder="Search for an agent..."
value={String(values.persona_id ?? "")}
onValueChange={(val) =>
setFieldValue("persona_id", val ? Number(val) : null)
}
options={searchAgentOptions}
strict
/>
{viewSyncEnabledAgents && syncEnabledAgents.length > 0 && (
<div className="mt-4">
@@ -419,12 +440,14 @@ export function SlackChannelConfigFormFields({
</>
</SubLabel>
<SelectorFormField
name="persona_id"
options={nonSearchAgents.map((persona) => ({
name: persona.name,
value: persona.id,
}))}
<InputComboBox
placeholder="Search for an agent..."
value={String(values.persona_id ?? "")}
onValueChange={(val) =>
setFieldValue("persona_id", val ? Number(val) : null)
}
options={nonSearchAgentOptions}
strict
/>
</div>
)}

View File

@@ -73,7 +73,10 @@ export const MemoizedAnchor = memo(
: undefined;
if (!associatedDoc && !associatedSubQuestion) {
return <>{children}</>;
// Citation not resolved yet (data still streaming) — hide the
// raw [[N]](url) link entirely. It will render as a chip once
// the citation/document data arrives.
return <></>;
}
let icon: React.ReactNode = null;

View File

@@ -44,6 +44,8 @@ export interface MultiModelPanelProps {
errorStackTrace?: string | null;
/** Additional error details */
errorDetails?: Record<string, any> | null;
/** Whether any model is still streaming — disables preferred selection */
isGenerating?: boolean;
}
/**
@@ -73,19 +75,24 @@ export default function MultiModelPanel({
isRetryable,
errorStackTrace,
errorDetails,
isGenerating,
}: MultiModelPanelProps) {
const ModelIcon = getModelIcon(provider, modelName);
const canSelect = !isHidden && !isPreferred && !isGenerating;
const handlePanelClick = useCallback(() => {
if (!isHidden && !isPreferred) onSelect();
}, [isHidden, isPreferred, onSelect]);
if (canSelect) onSelect();
}, [canSelect, onSelect]);
const header = (
<div
className={cn(
"rounded-12",
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
"rounded-12 transition-colors",
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00",
canSelect && "cursor-pointer hover:bg-background-tint-02"
)}
onClick={handlePanelClick}
>
<ContentAction
sizePreset="main-ui"
@@ -140,13 +147,7 @@ export default function MultiModelPanel({
}
return (
<div
className={cn(
"flex flex-col gap-3 min-w-0 rounded-16 transition-colors",
!isPreferred && "cursor-pointer hover:bg-background-tint-02"
)}
onClick={handlePanelClick}
>
<div className="flex flex-col gap-3 min-w-0 rounded-16">
{header}
{errorMessage ? (
<div className="p-4">
@@ -163,6 +164,7 @@ export default function MultiModelPanel({
<AgentMessage
{...agentMessageProps}
hideFooter={isNonPreferredInSelection}
disableTTS
/>
</div>
)}

View File

@@ -1,6 +1,13 @@
"use client";
import { useState, useCallback, useMemo, useEffect, useRef } from "react";
import {
useState,
useCallback,
useMemo,
useEffect,
useLayoutEffect,
useRef,
} from "react";
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
import { Message } from "@/app/app/interfaces";
import { LlmManager } from "@/lib/hooks";
@@ -110,11 +117,27 @@ export default function MultiModelResponseView({
// Refs to each panel wrapper for height animation on deselect
const panelElsRef = useRef<Map<number, HTMLDivElement>>(new Map());
// Tracks which non-preferred panels overflow the preferred height cap
// Tracks which non-preferred panels overflow the preferred height cap.
// Measured via useLayoutEffect after maxHeight is applied to the DOM —
// ref callbacks fire before layout and can't reliably detect overflow.
const [overflowingPanels, setOverflowingPanels] = useState<Set<number>>(
new Set()
);
useLayoutEffect(() => {
if (preferredPanelHeight == null || preferredIndex === null) return;
const next = new Set<number>();
panelElsRef.current.forEach((el, idx) => {
if (idx === preferredIndex || hiddenPanels.has(idx)) return;
if (el.scrollHeight > el.clientHeight) next.add(idx);
});
setOverflowingPanels((prev) => {
if (prev.size === next.size && Array.from(prev).every((v) => next.has(v)))
return prev;
return next;
});
}, [preferredPanelHeight, preferredIndex, hiddenPanels, responses]);
const preferredPanelRef = useCallback((el: HTMLDivElement | null) => {
if (preferredRoRef.current) {
preferredRoRef.current.disconnect();
@@ -210,8 +233,10 @@ export default function MultiModelResponseView({
const response = responses.find((r) => r.modelIndex === modelIndex);
if (!response) return;
// Persist preferred response to backend + update local tree so the
// input bar unblocks (awaitingPreferredSelection clears).
// Persist preferred response + sync `latestChildNodeId`. Backend's
// `set_preferred_response` updates `latest_child_message_id`; if the
// frontend chain walk disagrees, the next follow-up fails with
// "not on the latest mainline".
if (parentMessage?.messageId && response.messageId && currentSessionId) {
setPreferredResponse(parentMessage.messageId, response.messageId).catch(
(err) => console.error("Failed to persist preferred response:", err)
@@ -227,6 +252,7 @@ export default function MultiModelResponseView({
updated.set(parentMessage.nodeId, {
...userMsg,
preferredResponseId: response.messageId,
latestChildNodeId: response.nodeId,
});
updateSessionMessageTree(currentSessionId, updated);
}
@@ -413,6 +439,7 @@ export default function MultiModelResponseView({
isRetryable: response.isRetryable,
errorStackTrace: response.errorStackTrace,
errorDetails: response.errorDetails,
isGenerating,
}),
[
preferredIndex,
@@ -426,6 +453,7 @@ export default function MultiModelResponseView({
onMessageSelection,
onRegenerate,
parentMessage,
isGenerating,
]
);
@@ -512,17 +540,6 @@ export default function MultiModelResponseView({
panelElsRef.current.delete(r.modelIndex);
}
if (isPref) preferredPanelRef(el);
if (capped && el) {
const doesOverflow = el.scrollHeight > el.clientHeight;
setOverflowingPanels((prev) => {
const had = prev.has(r.modelIndex);
if (doesOverflow === had) return prev;
const next = new Set(prev);
if (doesOverflow) next.add(r.modelIndex);
else next.delete(r.modelIndex);
return next;
});
}
}}
style={{
width: `${selectionEntered ? finalW : startW}px`,
@@ -533,21 +550,19 @@ export default function MultiModelResponseView({
: "none",
maxHeight: capped ? preferredPanelHeight : undefined,
overflow: capped ? "hidden" : undefined,
position: capped ? "relative" : undefined,
...(overflows
? {
maskImage:
"linear-gradient(to bottom, black calc(100% - 6rem), transparent 100%)",
WebkitMaskImage:
"linear-gradient(to bottom, black calc(100% - 6rem), transparent 100%)",
}
: {}),
}}
>
<div className={cn(isNonPref && "opacity-50")}>
<MultiModelPanel {...buildPanelProps(r, isNonPref)} />
</div>
{overflows && (
<div
className="absolute inset-x-0 bottom-0 h-24 pointer-events-none"
style={{
background:
"linear-gradient(to top, var(--background-tint-01) 0%, transparent 100%)",
}}
/>
)}
</div>
);
})}

View File

@@ -1,3 +1,25 @@
/* Map Tailwind Typography prose variables to the project's color tokens.
These auto-switch for dark mode via colors.css — no dark: modifier needed.
Note: text-05 = highest contrast, text-01 = lowest. */
.prose-onyx {
--tw-prose-body: var(--text-05);
--tw-prose-headings: var(--text-05);
--tw-prose-lead: var(--text-04);
--tw-prose-links: var(--action-link-05);
--tw-prose-bold: var(--text-05);
--tw-prose-counters: var(--text-03);
--tw-prose-bullets: var(--text-03);
--tw-prose-hr: var(--border-02);
--tw-prose-quotes: var(--text-04);
--tw-prose-quote-borders: var(--border-02);
--tw-prose-captions: var(--text-03);
--tw-prose-code: var(--text-05);
--tw-prose-pre-code: var(--text-04);
--tw-prose-pre-bg: var(--background-code-01);
--tw-prose-th-borders: var(--border-02);
--tw-prose-td-borders: var(--border-01);
}
/* Light mode syntax highlighting (Atom One Light) */
.hljs {
color: #383a42 !important;
@@ -236,23 +258,102 @@ pre[class*="language-"] {
scrollbar-color: #4b5563 #1f2937;
}
/* Card wrapper — holds the background, border-radius, padding, and fade overlay.
Does NOT scroll — the inner .markdown-table-breakout handles that. */
.markdown-table-card {
position: relative;
background: var(--background-neutral-01);
border-radius: 0.5rem;
padding: 0.5rem 0;
}
/*
* Table breakout container - allows tables to extend beyond their parent's
* constrained width to use the full container query width (100cqw).
*
* Requires an ancestor element with `container-type: inline-size` (@container in Tailwind).
*
* How the math works:
* - width: 100cqw → expand to full container query width
* - marginLeft: calc((100% - 100cqw) / 2) → negative margin pulls element left
* (100% is parent width, 100cqw is larger, so result is negative)
* - paddingLeft/Right: calc((100cqw - 100%) / 2) → padding keeps content aligned
* with original position while allowing scroll area to extend
* Scrollable table container — sits inside the card.
*/
.markdown-table-breakout {
overflow-x: auto;
width: 100cqw;
margin-left: calc((100% - 100cqw) / 2);
padding-left: calc((100cqw - 100%) / 2);
padding-right: calc((100cqw - 100%) / 2);
/* Always reserve scrollbar height so hover doesn't shift content.
Thumb is transparent by default, revealed on hover. */
scrollbar-width: thin; /* Firefox — always shows track */
scrollbar-color: transparent transparent; /* invisible thumb + track */
}
.markdown-table-breakout::-webkit-scrollbar {
height: 6px;
}
.markdown-table-breakout::-webkit-scrollbar-track {
background: transparent;
}
.markdown-table-breakout::-webkit-scrollbar-thumb {
background: transparent;
border-radius: 3px;
}
.markdown-table-breakout:hover {
scrollbar-color: var(--border-03) transparent; /* Firefox — reveal thumb */
}
.markdown-table-breakout:hover::-webkit-scrollbar-thumb {
background: var(--border-03);
}
/* Fade the right edge via an ::after overlay on the non-scrolling card.
Stays pinned while table scrolls; doesn't affect the sticky column. */
.markdown-table-card::after {
content: "";
position: absolute;
top: 0;
right: 0;
bottom: 0;
width: 2rem;
pointer-events: none;
z-index: 2;
background: linear-gradient(
to right,
transparent,
var(--background-neutral-01)
);
border-radius: 0 0.5rem 0.5rem 0;
opacity: 0;
transition: opacity 0.15s;
}
.markdown-table-card[data-overflows="true"]::after {
opacity: 1;
}
/* Sticky first column — inherits the container's background so it
matches regardless of theme or custom wallpaper. */
.markdown-table-breakout th:first-child,
.markdown-table-breakout td:first-child {
position: sticky;
left: 0;
z-index: 1;
padding-left: 0.75rem;
background: var(--background-neutral-01);
}
.markdown-table-breakout th:last-child,
.markdown-table-breakout td:last-child {
padding-right: 0.75rem;
}
/* Shadow on sticky column when scrolled. Uses an ::after pseudo-element
so it isn't clipped by the overflow container or the mask-image fade. */
.markdown-table-breakout th:first-child::after,
.markdown-table-breakout td:first-child::after {
content: "";
position: absolute;
top: 0;
right: -6px;
bottom: 0;
width: 6px;
pointer-events: none;
opacity: 0;
transition: opacity 0.15s;
box-shadow: inset 6px 0 8px -4px var(--alpha-grey-100-25);
}
.dark .markdown-table-breakout th:first-child::after,
.dark .markdown-table-breakout td:first-child::after {
box-shadow: inset 6px 0 8px -4px var(--alpha-grey-100-60);
}
.markdown-table-breakout[data-scrolled="true"] th:first-child::after,
.markdown-table-breakout[data-scrolled="true"] td:first-child::after {
opacity: 1;
}

View File

@@ -51,6 +51,8 @@ export interface AgentMessageProps {
processingDurationSeconds?: number;
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
hideFooter?: boolean;
/** Skip TTS streaming (used in multi-model where voice doesn't apply) */
disableTTS?: boolean;
}
// TODO: Consider more robust comparisons:
@@ -99,6 +101,7 @@ const AgentMessage = React.memo(function AgentMessage({
parentMessage,
processingDurationSeconds,
hideFooter,
disableTTS,
}: AgentMessageProps) {
const markdownRef = useRef<HTMLDivElement>(null);
const finalAnswerRef = useRef<HTMLDivElement>(null);
@@ -133,32 +136,49 @@ const AgentMessage = React.memo(function AgentMessage({
finalAnswerComing
);
// Memoize merged citations separately to avoid creating new object when neither source changed
// Merge streaming citation/document data with chatState props.
// NOTE: citationMap and documentMap from usePacketProcessor are mutated in
// place (same object reference), so we use citations.length / documentMap.size
// as change-detection proxies to bust the memo cache when new data arrives.
const mergedCitations = useMemo(
() => ({
...chatState.citations,
...citationMap,
}),
[chatState.citations, citationMap]
// eslint-disable-next-line react-hooks/exhaustive-deps
[chatState.citations, citationMap, citations.length]
);
// Create a chatState that uses streaming citations for immediate rendering
// This merges the prop citations with streaming citations, preferring streaming ones
// Memoized with granular dependencies to prevent cascading re-renders
// Merge streaming documentMap into chatState.docs so inline citation chips
// can resolve [1] → document even when chatState.docs is empty (multi-model).
const mergedDocs = useMemo(() => {
const propDocs = chatState.docs ?? [];
if (documentMap.size === 0) return propDocs;
const seen = new Set(propDocs.map((d) => d.document_id));
const extras = Array.from(documentMap.values()).filter(
(d) => !seen.has(d.document_id)
);
return extras.length > 0 ? [...propDocs, ...extras] : propDocs;
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [chatState.docs, documentMap, documentMap.size]);
// Create a chatState that uses streaming citations and documents for immediate rendering.
// Memoized with granular dependencies to prevent cascading re-renders.
// Note: chatState object is recreated upstream on every render, so we depend on
// individual fields instead of the whole object for proper memoization
// individual fields instead of the whole object for proper memoization.
const effectiveChatState = useMemo<FullChatState>(
() => ({
...chatState,
citations: mergedCitations,
docs: mergedDocs,
}),
[
chatState.agent,
chatState.docs,
chatState.setPresentingDocument,
chatState.overriddenModel,
chatState.researchType,
mergedCitations,
mergedDocs,
]
);
@@ -202,6 +222,9 @@ const AgentMessage = React.memo(function AgentMessage({
// Skip if we've already finished TTS for this message
if (ttsCompletedRef.current) return;
// Multi-model: skip TTS entirely
if (disableTTS) return;
// If user cancelled generation, do not send more text to TTS.
if (stopPacketSeen && stopReason === StopReason.USER_CANCELLED) {
ttsCompletedRef.current = true;
@@ -305,7 +328,7 @@ const AgentMessage = React.memo(function AgentMessage({
onRenderComplete();
}
}}
animate={false}
animate={!stopPacketSeen}
stopPacketSeen={stopPacketSeen}
stopReason={stopReason}
>

View File

@@ -59,7 +59,6 @@ function TTSButton({ text, voice, speed }: TTSButtonProps) {
// Surface streaming voice playback errors to the user via toast
useEffect(() => {
if (error) {
console.error("Voice playback error:", error);
toast.error(error);
}
}, [error]);

View File

@@ -1,4 +1,4 @@
import React, { useCallback, useMemo, JSX } from "react";
import React, { useCallback, useEffect, useRef, useMemo, JSX } from "react";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
@@ -17,10 +17,79 @@ import { transformLinkUri, cn } from "@/lib/utils";
import { InMessageImage } from "@/app/app/components/files/images/InMessageImage";
import { extractChatImageFileId } from "@/app/app/components/files/images/utils";
/** Table wrapper that detects horizontal overflow and shows a fade + scrollbar. */
interface ScrollableTableProps
extends React.TableHTMLAttributes<HTMLTableElement> {
children: React.ReactNode;
}
export function ScrollableTable({
className,
children,
...props
}: ScrollableTableProps) {
const scrollRef = useRef<HTMLDivElement>(null);
const wrapRef = useRef<HTMLDivElement>(null);
const tableRef = useRef<HTMLTableElement>(null);
useEffect(() => {
const el = scrollRef.current;
const wrap = wrapRef.current;
const table = tableRef.current;
if (!el || !wrap) return;
const check = () => {
const overflows = el.scrollWidth > el.clientWidth;
const atEnd = el.scrollLeft + el.clientWidth >= el.scrollWidth - 2;
wrap.dataset.overflows = overflows && !atEnd ? "true" : "false";
el.dataset.scrolled = el.scrollLeft > 0 ? "true" : "false";
};
check();
el.addEventListener("scroll", check, { passive: true });
// Observe both the scroll container (parent resize) and the table
// itself (content growth during streaming).
const ro = new ResizeObserver(check);
ro.observe(el);
if (table) ro.observe(table);
return () => {
el.removeEventListener("scroll", check);
ro.disconnect();
};
}, []);
return (
<div ref={wrapRef} className="markdown-table-card">
<div ref={scrollRef} className="markdown-table-breakout">
<table
ref={tableRef}
className={cn(
className,
"min-w-full !my-0 [&_th]:whitespace-nowrap [&_td]:whitespace-nowrap"
)}
{...props}
>
{children}
</table>
</div>
</div>
);
}
/**
* Processes content for markdown rendering by handling code blocks and LaTeX
*/
export const processContent = (content: string): string => {
// Strip incomplete citation links at the end of streaming content.
// During typewriter animation, [[N]](url) is revealed character by character.
// ReactMarkdown can't parse an incomplete link and renders it as raw text.
// This regex removes any trailing partial citation pattern so only complete
// links are passed to the markdown parser.
content = content.replace(/\[\[\d+\]\]\([^)]*$/, "");
// Also strip a lone [[ or [[N] or [[N]] at the very end (before the URL part arrives)
content = content.replace(/\[\[(?:\d+\]?\]?)?$/, "");
const codeBlockRegex = /```(\w*)\n[\s\S]*?```|```[\s\S]*?$/g;
const matches = content.match(codeBlockRegex);
@@ -127,11 +196,9 @@ export const useMarkdownComponents = (
},
table: ({ node, className, children, ...props }: any) => {
return (
<div className="markdown-table-breakout">
<table className={cn(className, "min-w-full")} {...props}>
{children}
</table>
</div>
<ScrollableTable className={className} {...props}>
{children}
</ScrollableTable>
);
},
code: ({ node, className, children }: any) => {

View File

@@ -1,6 +1,14 @@
import React, { useEffect, useMemo, useRef, useState } from "react";
import Text from "@/refresh-components/texts/Text";
import ReactMarkdown, { Components } from "react-markdown";
import type { PluggableList } from "unified";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
import rehypeHighlight from "rehype-highlight";
import rehypeKatex from "rehype-katex";
import "katex/dist/katex.min.css";
import { useTypewriter } from "@/hooks/useTypewriter";
import Text from "@/refresh-components/texts/Text";
import {
ChatPacket,
PacketType,
@@ -8,16 +16,22 @@ import {
} from "../../../services/streamingModels";
import { MessageRenderer, FullChatState } from "../interfaces";
import { isFinalAnswerComplete } from "../../../services/packetUtils";
import { useMarkdownRenderer } from "../markdownUtils";
import { processContent } from "../markdownUtils";
import { BlinkingBar } from "../../BlinkingBar";
import { useVoiceMode } from "@/providers/VoiceModeProvider";
import {
MemoizedAnchor,
MemoizedParagraph,
} from "@/app/app/message/MemoizedTextComponents";
import { extractCodeText } from "@/app/app/message/codeUtils";
import { CodeBlock } from "@/app/app/message/CodeBlock";
import { InMessageImage } from "@/app/app/components/files/images/InMessageImage";
import { extractChatImageFileId } from "@/app/app/components/files/images/utils";
import { cn, transformLinkUri } from "@/lib/utils";
/**
* Maps a cleaned character position to the corresponding position in markdown text.
* This allows progressive reveal to work with markdown formatting.
*/
/** Maps a visible-char count to a markdown index (skips formatting chars,
* extends to word boundary). Used by the voice-sync reveal path only. */
function getRevealPosition(markdown: string, cleanChars: number): number {
// Skip patterns that don't contribute to visible character count
const skipChars = new Set(["*", "`", "#"]);
let cleanIndex = 0;
let mdIndex = 0;
@@ -25,13 +39,11 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
while (cleanIndex < cleanChars && mdIndex < markdown.length) {
const char = markdown[mdIndex];
// Skip markdown formatting characters
if (char !== undefined && skipChars.has(char)) {
mdIndex++;
continue;
}
// Handle link syntax [text](url) - skip the (url) part but count the text
if (
char === "]" &&
mdIndex + 1 < markdown.length &&
@@ -48,7 +60,6 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
mdIndex++;
}
// Extend to word boundary to avoid cutting mid-word
while (
mdIndex < markdown.length &&
markdown[mdIndex] !== " " &&
@@ -60,8 +71,15 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
return mdIndex;
}
// Control the rate of packet streaming (packets per second)
const PACKET_DELAY_MS = 10;
// Cheap streaming plugins (gfm only) → cheap per-frame parse. Full
// pipeline flips in once, at the end, for syntax highlighting + math.
const STREAMING_REMARK_PLUGINS: PluggableList = [remarkGfm];
const STREAMING_REHYPE_PLUGINS: PluggableList = [];
const FULL_REMARK_PLUGINS: PluggableList = [
remarkGfm,
[remarkMath, { singleDollarTextMath: true }],
];
const FULL_REHYPE_PLUGINS: PluggableList = [rehypeHighlight, rehypeKatex];
export const MessageTextRenderer: MessageRenderer<
ChatPacket,
@@ -78,19 +96,17 @@ export const MessageTextRenderer: MessageRenderer<
stopReason,
children,
}) => {
// If we're animating and the final answer is already complete, show more packets initially
const initialPacketCount = animate
? packets.length > 0
? 1 // Otherwise start with 1 packet
: 0
: -1; // Show all if not animating
const [displayedPacketCount, setDisplayedPacketCount] =
useState(initialPacketCount);
const lastStableSyncedContentRef = useRef("");
const lastVisibleContentRef = useRef("");
// Get voice mode context for progressive text reveal synced with audio
// Timeout guard: if TTS doesn't start within 5s of voice sync
// activating, fall back to normal streaming. Prevents permanent
// content suppression when the voice WebSocket fails to connect.
const [voiceSyncTimedOut, setVoiceSyncTimedOut] = useState(false);
const voiceSyncTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
null
);
const {
revealedCharCount,
autoPlayback,
@@ -99,7 +115,6 @@ export const MessageTextRenderer: MessageRenderer<
isAwaitingAutoPlaybackStart,
} = useVoiceMode();
// Get the full content from all packets
const fullContent = packets
.map((packet) => {
if (
@@ -114,117 +129,74 @@ export const MessageTextRenderer: MessageRenderer<
const shouldUseAutoPlaybackSync =
autoPlayback &&
!voiceSyncTimedOut &&
typeof messageNodeId === "number" &&
activeMessageNodeId === messageNodeId;
// Animation effect - gradually increase displayed packets at controlled rate
// Start/clear the timeout when voice sync activates/deactivates.
useEffect(() => {
if (!animate) {
setDisplayedPacketCount(-1); // Show all packets
return;
}
if (displayedPacketCount >= 0 && displayedPacketCount < packets.length) {
const timer = setTimeout(() => {
setDisplayedPacketCount((prev) => Math.min(prev + 1, packets.length));
}, PACKET_DELAY_MS);
return () => clearTimeout(timer);
}
}, [animate, displayedPacketCount, packets.length]);
// Reset displayed count when packet array changes significantly (e.g., new message)
useEffect(() => {
if (animate && packets.length < displayedPacketCount) {
const resetCount = isFinalAnswerComplete(packets)
? Math.min(10, packets.length)
: packets.length > 0
? 1
: 0;
setDisplayedPacketCount(resetCount);
}
}, [animate, packets.length, displayedPacketCount]);
// Only mark as complete when all packets are received AND displayed
useEffect(() => {
if (isFinalAnswerComplete(packets)) {
// If animating, wait until all packets are displayed
if (
animate &&
displayedPacketCount >= 0 &&
displayedPacketCount < packets.length
) {
return;
if (shouldUseAutoPlaybackSync && isAwaitingAutoPlaybackStart) {
if (!voiceSyncTimeoutRef.current) {
voiceSyncTimeoutRef.current = setTimeout(() => {
setVoiceSyncTimedOut(true);
}, 5000);
}
onComplete();
} else {
// TTS started or sync deactivated — clear timeout
if (voiceSyncTimeoutRef.current) {
clearTimeout(voiceSyncTimeoutRef.current);
voiceSyncTimeoutRef.current = null;
}
if (voiceSyncTimedOut && !autoPlayback) setVoiceSyncTimedOut(false);
}
}, [packets, onComplete, animate, displayedPacketCount]);
return () => {
if (voiceSyncTimeoutRef.current) {
clearTimeout(voiceSyncTimeoutRef.current);
voiceSyncTimeoutRef.current = null;
}
};
}, [
shouldUseAutoPlaybackSync,
isAwaitingAutoPlaybackStart,
isAudioSyncActive,
voiceSyncTimedOut,
]);
// Get content based on displayed packet count or audio progress
// Normal streaming hands full text to the typewriter. Voice-sync
// paths pre-slice and bypass. If shouldUseAutoPlaybackSync is false
// (including after the 5s timeout), all paths fall through to fullContent.
const computedContent = useMemo(() => {
// Hold response in "thinking" state only while autoplay startup is pending.
if (shouldUseAutoPlaybackSync && isAwaitingAutoPlaybackStart) {
return "";
}
// Sync text with audio only for the message currently being spoken.
if (shouldUseAutoPlaybackSync && isAudioSyncActive) {
const MIN_REVEAL_CHARS = 12;
if (revealedCharCount < MIN_REVEAL_CHARS) {
return "";
}
// Reveal text progressively based on audio progress
const revealPos = getRevealPosition(fullContent, revealedCharCount);
return fullContent.slice(0, Math.max(revealPos, 0));
}
// During an active synced turn, if sync temporarily drops, keep current reveal
// instead of jumping to full content or blanking.
if (shouldUseAutoPlaybackSync && !stopPacketSeen) {
return lastStableSyncedContentRef.current;
}
// Standard behavior when auto-playback is off
if (!animate || displayedPacketCount === -1) {
return fullContent; // Show all content
}
// Packet-based reveal (when auto-playback is disabled)
return packets
.slice(0, displayedPacketCount)
.map((packet) => {
if (
packet.obj.type === PacketType.MESSAGE_DELTA ||
packet.obj.type === PacketType.MESSAGE_START
) {
return packet.obj.content;
}
return "";
})
.join("");
return fullContent;
}, [
animate,
displayedPacketCount,
fullContent,
packets,
revealedCharCount,
autoPlayback,
isAudioSyncActive,
activeMessageNodeId,
isAwaitingAutoPlaybackStart,
messageNodeId,
shouldUseAutoPlaybackSync,
isAwaitingAutoPlaybackStart,
isAudioSyncActive,
revealedCharCount,
fullContent,
stopPacketSeen,
]);
// Keep synced text monotonic: once visible, never regress or disappear between chunks.
// Monotonic guard for voice sync + freeze on user cancel.
const content = useMemo(() => {
const wasUserCancelled = stopReason === StopReason.USER_CANCELLED;
// On user cancel during live streaming, freeze at exactly what was already
// visible to prevent flicker. On history reload (animate=false), the ref
// starts empty so we must use computedContent directly.
if (wasUserCancelled && animate) {
return lastVisibleContentRef.current;
}
@@ -242,13 +214,10 @@ export const MessageTextRenderer: MessageRenderer<
return computedContent;
}
// If content shape changed unexpectedly mid-stream, prefer the stable version
// to avoid flicker/dumps.
if (!stopPacketSeen || wasUserCancelled) {
return last;
}
// For normal completed responses, allow final full content.
return computedContent;
}, [
computedContent,
@@ -258,7 +227,6 @@ export const MessageTextRenderer: MessageRenderer<
animate,
]);
// Sync the stable ref outside of useMemo to avoid side effects during render.
useEffect(() => {
if (stopReason === StopReason.USER_CANCELLED) {
return;
@@ -270,13 +238,128 @@ export const MessageTextRenderer: MessageRenderer<
}
}, [content, shouldUseAutoPlaybackSync, stopReason]);
// Track last actually rendered content so cancel can freeze without dumping buffered text.
useEffect(() => {
if (content.length > 0) {
lastVisibleContentRef.current = content;
}
}, [content]);
const isStreamingAnimationEnabled =
animate &&
!shouldUseAutoPlaybackSync &&
stopReason !== StopReason.USER_CANCELLED;
const isStreamFinished = isFinalAnswerComplete(packets);
const displayedContent = useTypewriter(content, isStreamingAnimationEnabled);
// One-way signal: stream done AND typewriter caught up. Do NOT derive
// this from "typewriter currently behind" — it oscillates mid-stream
// between packet bursts and would thrash the plugin pipeline.
const streamFullyDisplayed =
isStreamFinished && displayedContent.length >= content.length;
// Fire onComplete exactly once per mount. `onComplete` is an inline
// arrow in AgentMessage so its identity changes on every parent render;
// without this guard, each new identity would re-fire the effect once
// `streamFullyDisplayed` is true.
const onCompleteFiredRef = useRef(false);
useEffect(() => {
if (streamFullyDisplayed && !onCompleteFiredRef.current) {
onCompleteFiredRef.current = true;
onComplete();
}
}, [streamFullyDisplayed, onComplete]);
const processedContent = useMemo(
() => processContent(displayedContent),
[displayedContent]
);
// Stable-identity components for ReactMarkdown. Dynamic data (`state`,
// `processedContent`) flows through refs so the callback identities
// never change — otherwise every typewriter tick would invalidate
// React reconciliation on the markdown subtree.
const stateRef = useRef(state);
stateRef.current = state;
const processedContentRef = useRef(processedContent);
processedContentRef.current = processedContent;
const markdownComponents = useMemo<Components>(
() => ({
a: ({ href, children }) => {
const s = stateRef.current;
const imageFileId = extractChatImageFileId(
href,
String(children ?? "")
);
if (imageFileId) {
return (
<InMessageImage
fileId={imageFileId}
fileName={String(children ?? "")}
/>
);
}
return (
<MemoizedAnchor
updatePresentingDocument={s?.setPresentingDocument || (() => {})}
docs={s?.docs || []}
userFiles={s?.userFiles || []}
citations={s?.citations}
href={href}
>
{children}
</MemoizedAnchor>
);
},
p: ({ children }) => (
<MemoizedParagraph className="font-main-content-body">
{children}
</MemoizedParagraph>
),
pre: ({ children }) => <>{children}</>,
b: ({ className, children }) => (
<span className={className}>{children}</span>
),
ul: ({ className, children, ...rest }) => (
<ul className={className} {...rest}>
{children}
</ul>
),
ol: ({ className, children, ...rest }) => (
<ol className={className} {...rest}>
{children}
</ol>
),
li: ({ className, children, ...rest }) => (
<li className={className} {...rest}>
{children}
</li>
),
table: ({ className, children, ...rest }) => (
<div className="markdown-table-breakout">
<table className={cn(className, "min-w-full")} {...rest}>
{children}
</table>
</div>
),
code: ({ node, className, children }) => {
const codeText = extractCodeText(
node,
processedContentRef.current,
children
);
return (
<CodeBlock className={className} codeText={codeText}>
{children}
</CodeBlock>
);
},
}),
[]
);
const shouldShowThinkingPlaceholder =
shouldUseAutoPlaybackSync &&
isAwaitingAutoPlaybackStart &&
@@ -292,16 +375,16 @@ export const MessageTextRenderer: MessageRenderer<
!stopPacketSeen;
const shouldShowCursor =
content.length > 0 &&
(!stopPacketSeen ||
displayedContent.length > 0 &&
((isStreamingAnimationEnabled && !streamFullyDisplayed) ||
(!isStreamingAnimationEnabled && !stopPacketSeen) ||
(shouldUseAutoPlaybackSync && content.length < fullContent.length));
const { renderedContent } = useMarkdownRenderer(
// the [*]() is a hack to show a blinking dot when the packet is not complete
shouldShowCursor ? content + " [*]() " : content,
state,
"font-main-content-body"
);
// `[*]() ` is rendered by the anchor component as an inline blinking
// caret, keeping it flush with the trailing character.
const markdownInput = shouldShowCursor
? processedContent + " [*]() "
: processedContent;
return children([
{
@@ -312,8 +395,26 @@ export const MessageTextRenderer: MessageRenderer<
<Text as="span" secondaryBody text04 className="italic">
Thinking
</Text>
) : content.length > 0 ? (
<>{renderedContent}</>
) : displayedContent.length > 0 ? (
<div dir="auto">
<ReactMarkdown
className="prose prose-onyx font-main-content-body max-w-full"
components={markdownComponents}
remarkPlugins={
streamFullyDisplayed
? FULL_REMARK_PLUGINS
: STREAMING_REMARK_PLUGINS
}
rehypePlugins={
streamFullyDisplayed
? FULL_REHYPE_PLUGINS
: STREAMING_REHYPE_PLUGINS
}
urlTransform={transformLinkUri}
>
{markdownInput}
</ReactMarkdown>
</div>
) : (
<BlinkingBar addMargin />
),

View File

@@ -34,7 +34,8 @@ export const PROVIDERS: ProviderConfig[] = [
providerName: LLMProviderName.ANTHROPIC,
recommended: true,
models: [
{ name: "claude-opus-4-6", label: "Claude Opus 4.6", recommended: true },
{ name: "claude-opus-4-7", label: "Claude Opus 4.7", recommended: true },
{ name: "claude-opus-4-6", label: "Claude Opus 4.6" },
{ name: "claude-sonnet-4-6", label: "Claude Sonnet 4.6" },
],
apiKeyPlaceholder: "sk-ant-...",

View File

@@ -5,12 +5,12 @@
export interface BuildLlmSelection {
providerName: string; // e.g., "build-mode-anthropic" (LLMProviderDescriptor.name)
provider: string; // e.g., "anthropic"
modelName: string; // e.g., "claude-opus-4-6"
modelName: string; // e.g., "claude-opus-4-7"
}
// Priority order for smart default LLM selection
const LLM_SELECTION_PRIORITY = [
{ provider: "anthropic", modelName: "claude-opus-4-6" },
{ provider: "anthropic", modelName: "claude-opus-4-7" },
{ provider: "openai", modelName: "gpt-5.2" },
{ provider: "openrouter", modelName: "minimax/minimax-m2.1" },
] as const;
@@ -63,10 +63,11 @@ export function getDefaultLlmSelection(
export const RECOMMENDED_BUILD_MODELS = {
preferred: {
provider: "anthropic",
modelName: "claude-opus-4-6",
displayName: "Claude Opus 4.6",
modelName: "claude-opus-4-7",
displayName: "Claude Opus 4.7",
},
alternatives: [
{ provider: "anthropic", modelName: "claude-opus-4-6" },
{ provider: "anthropic", modelName: "claude-sonnet-4-6" },
{ provider: "openai", modelName: "gpt-5.2" },
{ provider: "openai", modelName: "gpt-5.1-codex" },
@@ -148,7 +149,8 @@ export const BUILD_MODE_PROVIDERS: BuildModeProvider[] = [
providerName: "anthropic",
recommended: true,
models: [
{ name: "claude-opus-4-6", label: "Claude Opus 4.6", recommended: true },
{ name: "claude-opus-4-7", label: "Claude Opus 4.7", recommended: true },
{ name: "claude-opus-4-6", label: "Claude Opus 4.6" },
{ name: "claude-sonnet-4-6", label: "Claude Sonnet 4.6" },
],
apiKeyPlaceholder: "sk-ant-...",

View File

@@ -320,7 +320,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
onSubmit({
message: submittedMessage,
currentMessageFiles: currentMessageFiles,
deepResearch: deepResearchEnabled,
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
additionalContext,
selectedModels,
});
@@ -332,7 +332,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
onSubmit({
message: chatMessage,
currentMessageFiles: currentMessageFiles,
deepResearch: deepResearchEnabled,
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
additionalContext,
selectedModels,
});
@@ -370,10 +370,16 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
onSubmit({
message: lastUserMsg.message,
currentMessageFiles: currentMessageFiles,
deepResearch: deepResearchEnabled,
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
messageIdToResend: lastUserMsg.messageId,
});
}, [messageHistory, onSubmit, currentMessageFiles, deepResearchEnabled]);
}, [
messageHistory,
onSubmit,
currentMessageFiles,
deepResearchEnabled,
multiModel.isMultiModelActive,
]);
// Start a new chat session in the side panel
const handleNewChat = useCallback(() => {
@@ -516,8 +522,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
ref={inputRef}
className={cn(
"w-full flex flex-col",
!isSidePanel &&
"max-w-[var(--app-page-main-content-width)] px-4"
!isSidePanel && "max-w-[var(--app-page-main-content-width)]"
)}
>
{hasMessages && liveAgent && !llmManager.isLoadingProviders && (
@@ -535,6 +540,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
ref={chatInputBarRef}
deepResearchEnabled={deepResearchEnabled}
toggleDeepResearch={toggleDeepResearch}
isMultiModelActive={multiModel.isMultiModelActive}
filterManager={filterManager}
llmManager={llmManager}
initialMessage={message}

View File

@@ -3,7 +3,7 @@
import { ValidSources } from "@/lib/types";
import { SourceIcon } from "./SourceIcon";
import { useState } from "react";
import { OnyxIcon } from "./icons/icons";
import { GithubIcon, OnyxIcon } from "./icons/icons";
export function WebResultIcon({
url,
@@ -23,6 +23,8 @@ export function WebResultIcon({
<>
{hostname.includes("onyx.app") ? (
<OnyxIcon size={size} className="dark:text-[#fff] text-[#000]" />
) : hostname === "github.com" || hostname.endsWith(".github.com") ? (
<GithubIcon size={size} />
) : !error ? (
<img
className="my-0 rounded-full py-0"

View File

@@ -46,6 +46,7 @@ import freshdeskIcon from "@public/Freshdesk.png";
import geminiSVG from "@public/Gemini.svg";
import gitbookDarkIcon from "@public/GitBookDark.png";
import gitbookLightIcon from "@public/GitBookLight.png";
import githubDarkIcon from "@public/GithubDarkMode.png";
import githubLightIcon from "@public/Github.png";
import gongIcon from "@public/Gong.png";
import googleIcon from "@public/Google.png";
@@ -855,7 +856,7 @@ export const GitbookIcon = createLogoIcon(gitbookDarkIcon, {
darkSrc: gitbookLightIcon,
});
export const GithubIcon = createLogoIcon(githubLightIcon, {
monochromatic: true,
darkSrc: githubDarkIcon,
});
export const GitlabIcon = createLogoIcon(gitlabIcon);
export const GmailIcon = createLogoIcon(gmailIcon);

View File

@@ -644,6 +644,7 @@ export default function useChatController({
});
node.modelDisplayName = model.displayName;
node.overridden_model = model.modelName;
node.is_generating = true;
return node;
});
}
@@ -711,6 +712,13 @@ export default function useChatController({
? selectedModels?.map((m) => m.displayName) ?? []
: [];
// rAF-batched flush state. One Zustand write per frame instead of
// one per packet.
const dirtyModelIndices = new Set<number>();
let singleModelDirty = false;
let userNodeDirty = false;
let pendingFlush = false;
/** Build a non-errored multi-model assistant node for upsert. */
function buildAssistantNodeUpdate(
idx: number,
@@ -740,16 +748,124 @@ export default function useChatController({
};
}
/** Build updated nodes for all non-errored models. */
function buildNonErroredNodes(overrides?: Partial<Message>): Message[] {
/** With `onlyDirty`, rebuilds only those model nodes — unchanged
* siblings keep their stable Message ref so React memo short-circuits. */
function buildNonErroredNodes(
overrides?: Partial<Message>,
onlyDirty?: Set<number> | null
): Message[] {
const nodes: Message[] = [];
for (let idx = 0; idx < initialAssistantNodes.length; idx++) {
if (erroredModelIndices.has(idx)) continue;
if (onlyDirty && !onlyDirty.has(idx)) continue;
nodes.push(buildAssistantNodeUpdate(idx, overrides));
}
return nodes;
}
/** Flush accumulated packet state into the tree as one Zustand
* update. No-op when nothing is pending. */
function flushPendingUpdates() {
if (!pendingFlush) return;
pendingFlush = false;
parentMessage =
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
let messagesToUpsert: Message[];
if (isMultiModel) {
if (dirtyModelIndices.size === 0 && !userNodeDirty) return;
const dirtySnapshot = new Set(dirtyModelIndices);
dirtyModelIndices.clear();
const dirtyNodes = buildNonErroredNodes(undefined, dirtySnapshot);
if (userNodeDirty) {
userNodeDirty = false;
// Read current user node to preserve childrenNodeIds
// (initialUserNode's are stale from creation time).
const currentUserNode =
currentMessageTreeLocal.get(initialUserNode.nodeId) ||
initialUserNode;
const updatedUserNode: Message = {
...currentUserNode,
messageId: newUserMessageId ?? undefined,
files: files,
};
messagesToUpsert = [updatedUserNode, ...dirtyNodes];
} else {
messagesToUpsert = dirtyNodes;
}
if (messagesToUpsert.length === 0) return;
} else {
if (!singleModelDirty) return;
singleModelDirty = false;
messagesToUpsert = [
{
...initialUserNode,
messageId: newUserMessageId ?? undefined,
files: files,
},
{
...initialAgentNode,
messageId: newAgentMessageId ?? undefined,
message: error || answer,
type: error ? "error" : "assistant",
retrievalType,
query: finalMessage?.rephrased_query || query,
documents: documents,
citations: finalMessage?.citations || citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCall: finalMessage?.tool_call || toolCall,
stackTrace: stackTrace,
overridden_model: finalMessage?.overridden_model,
stopReason: stopReason,
packets: packets,
packetCount: packets.length,
processingDurationSeconds:
finalMessage?.processing_duration_seconds ??
(() => {
const startTime = useChatSessionStore
.getState()
.getStreamingStartTime(frozenSessionId);
return startTime
? Math.floor((Date.now() - startTime) / 1000)
: undefined;
})(),
},
];
}
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: messagesToUpsert,
completeMessageTreeOverride: currentMessageTreeLocal,
chatSessionId: frozenSessionId!,
});
}
/** Awaits next animation frame (or a setTimeout fallback when the
* tab is hidden — rAF is paused in background tabs, which would
* otherwise hang the stream loop here), then flushes. Aligns
* React updates with the paint cycle when visible. */
function flushViaRAF(): Promise<void> {
return new Promise<void>((resolve) => {
let done = false;
const flush = () => {
if (done) return;
done = true;
flushPendingUpdates();
resolve();
};
requestAnimationFrame(flush);
// Fallback for hidden tabs where rAF is paused. Throttled to
// ~1s by browsers, matching the previous setTimeout(500) cadence.
setTimeout(flush, 100);
});
}
let streamSucceeded = false;
try {
@@ -836,7 +952,12 @@ export default function useChatController({
await delay(50);
while (!stack.isComplete || !stack.isEmpty()) {
if (stack.isEmpty()) {
await delay(0.5);
// Flush the burst on the next paint, or idle briefly.
if (pendingFlush) {
await flushViaRAF();
} else {
await delay(0.5);
}
}
if (!stack.isEmpty() && !controller.signal.aborted) {
@@ -860,6 +981,7 @@ export default function useChatController({
if ((packet as MessageResponseIDInfo).user_message_id) {
newUserMessageId = (packet as MessageResponseIDInfo)
.user_message_id;
userNodeDirty = true;
// Track extension queries in PostHog (reuses isExtension/extensionContext from above)
if (isExtension) {
@@ -898,6 +1020,8 @@ export default function useChatController({
modelDisplayNames[mi] = slot.model_name;
}
}
userNodeDirty = true;
pendingFlush = true;
continue;
}
@@ -909,6 +1033,7 @@ export default function useChatController({
!files.some((existingFile) => existingFile.id === newFile.id)
);
files = files.concat(newUserFiles);
if (newUserFiles.length > 0) userNodeDirty = true;
}
if (Object.hasOwn(packet, "file_ids")) {
@@ -928,15 +1053,20 @@ export default function useChatController({
// In multi-model mode, route per-model errors to the specific model's
// node instead of killing the entire stream. Other models keep streaming.
if (isMultiModel && streamingError.details?.model_index != null) {
const errorModelIndex = streamingError.details
.model_index as number;
if (isMultiModel) {
// Multi-model: isolate the error to its panel. Never throw
// or set global error state — other models keep streaming.
const errorModelIndex = streamingError.details?.model_index as
| number
| undefined;
if (
errorModelIndex != null &&
errorModelIndex >= 0 &&
errorModelIndex < initialAssistantNodes.length
) {
const errorNode = initialAssistantNodes[errorModelIndex]!;
erroredModelIndices.add(errorModelIndex);
dirtyModelIndices.delete(errorModelIndex);
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: [
{
@@ -963,8 +1093,15 @@ export default function useChatController({
completeMessageTreeOverride: currentMessageTreeLocal,
chatSessionId: frozenSessionId!,
});
} else {
// Error without model_index in multi-model — can't route
// to a specific panel. Log and continue; the stream loop
// stays alive for other models.
console.warn(
"Multi-model error without model_index:",
streamingError.error
);
}
// Skip the normal per-packet upsert — we already upserted the error node
continue;
} else {
// Single-model: kill the stream
@@ -993,19 +1130,21 @@ export default function useChatController({
if (isMultiModel) {
// Multi-model: route packet by placement.model_index.
// OverallStop (type "stop") has model_index=null — it's a global
// terminal packet that must be delivered to ALL models so each
// panel's AgentMessage sees the stop and exits "Thinking..." state.
// OverallStop (type "stop") has model_index=null — it's a
// global terminal packet that must be delivered to ALL
// models so each panel's AgentMessage sees the stop and
// exits "Thinking..." state.
const isGlobalStop =
packetObj.type === "stop" &&
typedPacket.placement?.model_index == null;
if (isGlobalStop) {
for (let mi = 0; mi < packetsPerModel.length; mi++) {
packetsPerModel[mi] = [
...packetsPerModel[mi]!,
typedPacket,
];
// Mutated in place — change detection uses packetCount, not array identity.
packetsPerModel[mi]!.push(typedPacket);
if (!erroredModelIndices.has(mi)) {
dirtyModelIndices.add(mi);
}
}
}
@@ -1015,10 +1154,10 @@ export default function useChatController({
modelIndex >= 0 &&
modelIndex < packetsPerModel.length
) {
packetsPerModel[modelIndex] = [
...packetsPerModel[modelIndex]!,
typedPacket,
];
packetsPerModel[modelIndex]!.push(typedPacket);
if (!erroredModelIndices.has(modelIndex)) {
dirtyModelIndices.add(modelIndex);
}
if (packetObj.type === "citation_info") {
const citationInfo = packetObj as {
@@ -1048,6 +1187,7 @@ export default function useChatController({
// Single-model
packets.push(typedPacket);
packetsVersion++;
singleModelDirty = true;
if (packetObj.type === "citation_info") {
const citationInfo = packetObj as {
@@ -1074,73 +1214,16 @@ export default function useChatController({
console.warn("Unknown packet:", JSON.stringify(packet));
}
// on initial message send, we insert a dummy system message
// set this as the parent here if no parent is set
parentMessage =
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
// Build the messages to upsert based on single vs multi-model mode
let messagesToUpsertInLoop: Message[];
if (isMultiModel) {
// Read the current user node from the tree to preserve childrenNodeIds
// (initialUserNode has stale/empty children from creation time).
const currentUserNode =
currentMessageTreeLocal.get(initialUserNode.nodeId) ||
initialUserNode;
const updatedUserNode: Message = {
...currentUserNode,
messageId: newUserMessageId ?? undefined,
files: files,
};
messagesToUpsertInLoop = [
updatedUserNode,
...buildNonErroredNodes(),
];
} else {
messagesToUpsertInLoop = [
{
...initialUserNode,
messageId: newUserMessageId ?? undefined,
files: files,
},
{
...initialAgentNode,
messageId: newAgentMessageId ?? undefined,
message: error || answer,
type: error ? "error" : "assistant",
retrievalType,
query: finalMessage?.rephrased_query || query,
documents: documents,
citations: finalMessage?.citations || citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCall: finalMessage?.tool_call || toolCall,
stackTrace: stackTrace,
overridden_model: finalMessage?.overridden_model,
stopReason: stopReason,
packets: packets,
packetCount: packets.length,
processingDurationSeconds:
finalMessage?.processing_duration_seconds ??
(() => {
const startTime = useChatSessionStore
.getState()
.getStreamingStartTime(frozenSessionId);
return startTime
? Math.floor((Date.now() - startTime) / 1000)
: undefined;
})(),
},
];
}
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: messagesToUpsertInLoop,
completeMessageTreeOverride: currentMessageTreeLocal,
chatSessionId: frozenSessionId!,
});
// Mark dirty — flushViaRAF coalesces bursts into one React update per frame.
if (!isMultiModel) singleModelDirty = true;
pendingFlush = true;
}
}
// Flush any tail state from the final packet(s) before declaring
// the stream complete. Without this, the last ≤1 frame of packets
// could get stranded in local state.
flushPendingUpdates();
// Surface FIFO errors (e.g. 429 before any packets arrive) so the
// catch block replaces the thinking placeholder with an error message.
if (stack.error) {
@@ -1174,6 +1257,7 @@ export default function useChatController({
errorCode,
isRetryable,
errorDetails,
is_generating: false,
})
: [
{

View File

@@ -106,9 +106,23 @@ export default function useMultiModelChat(
[currentLlmModel]
);
const removeModel = useCallback((index: number) => {
setSelectedModels((prev) => prev.filter((_, i) => i !== index));
}, []);
const removeModel = useCallback(
(index: number) => {
const next = selectedModels.filter((_, i) => i !== index);
// When dropping to single-model, switch llmManager to the surviving
// model so it becomes the active model instead of reverting to the
// user's default.
if (next.length === 1 && next[0]) {
llmManager.updateCurrentLlm({
name: next[0].name,
provider: next[0].provider,
modelName: next[0].modelName,
});
}
setSelectedModels(next);
},
[selectedModels, llmManager]
);
const replaceModel = useCallback(
(index: number, model: SelectedModel) => {

View File

@@ -48,6 +48,7 @@ describe("useSettings", () => {
anonymous_user_enabled: false,
invite_only_enabled: false,
deep_research_enabled: true,
multi_model_chat_enabled: true,
temperature_override_enabled: true,
query_history_type: QueryHistoryType.NORMAL,
});
@@ -65,6 +66,7 @@ describe("useSettings", () => {
anonymous_user_enabled: false,
invite_only_enabled: false,
deep_research_enabled: true,
multi_model_chat_enabled: true,
temperature_override_enabled: true,
query_history_type: QueryHistoryType.NORMAL,
};

View File

@@ -23,6 +23,7 @@ const DEFAULT_SETTINGS = {
anonymous_user_enabled: false,
invite_only_enabled: false,
deep_research_enabled: true,
multi_model_chat_enabled: true,
temperature_override_enabled: true,
query_history_type: QueryHistoryType.NORMAL,
} satisfies Settings;

View File

@@ -0,0 +1,134 @@
import { useEffect, useMemo, useRef, useState } from "react";
// Fixed reveal rate — NOT adaptive. Any ceil(delta/N) formula produces
// visible chunks on burst packet arrivals. 1 = 60 cps, 2 = 120 cps.
const CHARS_PER_FRAME = 3;
/**
* Reveals `target` one character at a time on each animation frame.
* When `enabled` is false (historical messages), snaps to full on mount.
* The rAF loop pauses once caught up and resumes when `target` grows.
*/
export function useTypewriter(target: string, enabled: boolean): string {
// Ref so the rAF loop reads latest length without restarting.
const targetRef = useRef(target);
targetRef.current = target;
// Mirror `enabled` so the restart effect can short-circuit when the
// caller has turned animation off (e.g. voice-mode, where display is
// driven by audio position — the typewriter must stay idle and not
// animate a jump after audio ends).
const enabledRef = useRef(enabled);
enabledRef.current = enabled;
// `enabled` controls initial state: animate from 0 vs snap to full for
// history/voice. Transitions mid-stream are handled via enabledRef in
// the restart effect so a flip to false doesn't dump the buffered tail
// *and* doesn't spin up the rAF loop on later growth.
const [displayedLength, setDisplayedLength] = useState<number>(
enabled ? 0 : target.length
);
// Mirror displayedLength in a ref so the rAF loop can read the latest
// value without stale-closure issues AND without needing a functional
// state updater (which must be pure — no ref mutations inside).
const displayedLengthRef = useRef(displayedLength);
// Clamp (not reset) on target shrink — preserves already-revealed chars
// across user-cancel freeze and regeneration.
const prevTargetLengthRef = useRef(target.length);
useEffect(() => {
if (target.length < prevTargetLengthRef.current) {
const clamped = Math.min(displayedLengthRef.current, target.length);
displayedLengthRef.current = clamped;
setDisplayedLength(clamped);
}
prevTargetLengthRef.current = target.length;
}, [target.length]);
// Self-scheduling rAF loop. Pauses when caught up so idle/historical
// messages don't run a 60fps no-op updater for their entire lifetime.
const rafIdRef = useRef<number | null>(null);
const runningRef = useRef(false);
const startLoopRef = useRef<(() => void) | null>(null);
useEffect(() => {
const tick = () => {
const targetLen = targetRef.current.length;
const prev = displayedLengthRef.current;
if (prev >= targetLen) {
// Caught up — pause the loop. The sibling effect below will
// restart it when `target` grows.
runningRef.current = false;
rafIdRef.current = null;
return;
}
const next = Math.min(prev + CHARS_PER_FRAME, targetLen);
displayedLengthRef.current = next;
setDisplayedLength(next);
rafIdRef.current = requestAnimationFrame(tick);
};
const start = () => {
if (runningRef.current) return;
// Animation disabled — snap to full and stay idle. This is the
// voice-mode path where content is driven by audio position, and
// any "gap" (e.g. user stops audio early) must jump instantly
// instead of animating a 1500-char typewriter burst.
if (!enabledRef.current) {
const targetLen = targetRef.current.length;
if (displayedLengthRef.current !== targetLen) {
displayedLengthRef.current = targetLen;
setDisplayedLength(targetLen);
}
return;
}
runningRef.current = true;
rafIdRef.current = requestAnimationFrame(tick);
};
startLoopRef.current = start;
if (targetRef.current.length > displayedLengthRef.current) {
start();
}
return () => {
runningRef.current = false;
if (rafIdRef.current !== null) {
cancelAnimationFrame(rafIdRef.current);
rafIdRef.current = null;
}
startLoopRef.current = null;
};
}, []);
// Restart the loop when target grows past what's currently displayed.
useEffect(() => {
if (target.length > displayedLength && startLoopRef.current) {
startLoopRef.current();
}
}, [target.length, displayedLength]);
// When the user navigates away and back (tab switch, window focus),
// snap to all collected content so they see the full response immediately.
useEffect(() => {
const handleVisibility = () => {
if (document.visibilityState === "visible") {
const targetLen = targetRef.current.length;
if (displayedLengthRef.current < targetLen) {
displayedLengthRef.current = targetLen;
setDisplayedLength(targetLen);
}
}
};
document.addEventListener("visibilitychange", handleVisibility);
return () =>
document.removeEventListener("visibilitychange", handleVisibility);
}, []);
return useMemo(
() => target.slice(0, Math.min(displayedLength, target.length)),
[target, displayedLength]
);
}

View File

@@ -27,6 +27,7 @@ export interface Settings {
query_history_type: QueryHistoryType;
deep_research_enabled?: boolean;
multi_model_chat_enabled?: boolean;
search_ui_enabled?: boolean;
// Image processing settings

View File

@@ -173,8 +173,13 @@ function AttachmentItemLayout({
rightChildren,
}: AttachmentItemLayoutProps) {
return (
<Section flexDirection="row" gap={0.25} padding={0.25}>
<div className={cn("h-[2.25rem] aspect-square rounded-08")}>
<Section
flexDirection="row"
justifyContent="start"
gap={0.25}
padding={0.25}
>
<div className={cn("h-[2.25rem] aspect-square rounded-08 flex-shrink-0")}>
<Section>
<div
className="attachment-button__icon-wrapper"
@@ -189,6 +194,7 @@ function AttachmentItemLayout({
justifyContent="between"
alignItems="center"
gap={1.5}
className="min-w-0"
>
<div data-testid="attachment-item-title" className="flex-1 min-w-0">
<Content

View File

@@ -9,6 +9,7 @@ import { useField, useFormikContext } from "formik";
import { Section } from "@/layouts/general-layouts";
import { Content } from "@opal/layouts";
import Label from "@/refresh-components/form/Label";
import type { TagProps } from "@opal/components/tag/components";
interface OrientationLayoutProps {
name?: string;
@@ -16,6 +17,8 @@ interface OrientationLayoutProps {
nonInteractive?: boolean;
children?: React.ReactNode;
title: string | RichStr;
/** Tag rendered inline beside the title (passed through to Content). */
tag?: TagProps;
description?: string | RichStr;
suffix?: "optional" | (string & {});
sizePreset?: "main-content" | "main-ui";
@@ -128,6 +131,7 @@ function HorizontalInputLayout({
children,
center,
title,
tag,
description,
suffix,
sizePreset = "main-content",
@@ -144,6 +148,7 @@ function HorizontalInputLayout({
title={title}
description={description}
suffix={suffix}
tag={tag}
sizePreset={sizePreset}
variant="section"
widthVariant="full"

View File

@@ -694,6 +694,25 @@ export function useLlmManager(
prevAgentIdRef.current = liveAgent?.id;
}, [liveAgent?.id]);
// Clear manual override when arriving at a *different* existing session
// from any previously-seen defined session. Tracks only the last
// *defined* session id so a round-trip through new-chat (A → undefined
// → B) still resets, while A → undefined (new-chat) preserves it.
const prevDefinedSessionIdRef = useRef<string | undefined>(undefined);
useEffect(() => {
const nextId = currentChatSession?.id;
if (
nextId !== undefined &&
prevDefinedSessionIdRef.current !== undefined &&
nextId !== prevDefinedSessionIdRef.current
) {
setUserHasManuallyOverriddenLLM(false);
}
if (nextId !== undefined) {
prevDefinedSessionIdRef.current = nextId;
}
}, [currentChatSession?.id]);
function getValidLlmDescriptor(
modelName: string | null | undefined
): LlmDescriptor {
@@ -715,8 +734,9 @@ export function useLlmManager(
if (llmProviders === undefined || llmProviders === null) {
resolved = manualLlm;
} else if (userHasManuallyOverriddenLLM && !currentChatSession) {
// User has overridden in this session and switched to a new session
} else if (userHasManuallyOverriddenLLM) {
// Manual override wins over session's `current_alternate_model`.
// Cleared on cross-session navigation by the effect above.
resolved = manualLlm;
} else if (currentChatSession?.current_alternate_model) {
resolved = getValidLlmDescriptorForProviders(
@@ -728,8 +748,6 @@ export function useLlmManager(
liveAgent.llm_model_version_override,
llmProviders
);
} else if (userHasManuallyOverriddenLLM) {
resolved = manualLlm;
} else if (user?.preferences?.default_model) {
resolved = getValidLlmDescriptorForProviders(
user.preferences.default_model,

View File

@@ -53,18 +53,17 @@ export class HTTPStreamingTTSPlayer {
// Create abort controller for this request
this.abortController = new AbortController();
// Build URL with query params
const params = new URLSearchParams();
params.set("text", text);
if (voice) params.set("voice", voice);
params.set("speed", speed.toString());
const url = `${this.getAPIUrl()}?${params}`;
const url = this.getAPIUrl();
const body = JSON.stringify({
text,
...(voice && { voice }),
speed,
});
// Check if MediaSource is supported
if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
// Fallback to simple buffered playback
return this.fallbackSpeak(url);
return this.fallbackSpeak(url, body);
}
// Create MediaSource and audio element
@@ -129,15 +128,21 @@ export class HTTPStreamingTTSPlayer {
try {
const response = await fetch(url, {
method: "POST",
headers: { "Content-Type": "application/json" },
body,
signal: this.abortController.signal,
credentials: "include", // Include cookies for authentication
credentials: "include",
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(
`TTS request failed: ${response.status} - ${errorText}`
);
let message = `TTS request failed (${response.status})`;
try {
const errorJson = await response.json();
if (errorJson.detail) message = errorJson.detail;
} catch {
// response wasn't JSON — use status text
}
throw new Error(message);
}
const reader = response.body?.getReader();
@@ -242,16 +247,24 @@ export class HTTPStreamingTTSPlayer {
* Fallback for browsers that don't support MediaSource Extensions.
* Buffers all audio before playing.
*/
private async fallbackSpeak(url: string): Promise<void> {
private async fallbackSpeak(url: string, body: string): Promise<void> {
const response = await fetch(url, {
method: "POST",
headers: { "Content-Type": "application/json" },
body,
signal: this.abortController?.signal,
credentials: "include", // Include cookies for authentication
credentials: "include",
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`TTS request failed: ${response.status} - ${errorText}`);
let message = `TTS request failed (${response.status})`;
try {
const errorJson = await response.json();
if (errorJson.detail) message = errorJson.detail;
} catch {
// response wasn't JSON — use status text
}
throw new Error(message);
}
const audioData = await response.arrayBuffer();

View File

@@ -1,6 +1,7 @@
"use client";
import { useEffect, useRef, useState } from "react";
import copy from "copy-to-clipboard";
import { Button, ButtonProps } from "@opal/components";
import { SvgAlertTriangle, SvgCheck, SvgCopy } from "@opal/icons";
@@ -40,26 +41,19 @@ export default function CopyIconButton({
}
try {
// Check if Clipboard API is available
if (!navigator.clipboard) {
throw new Error("Clipboard API not available");
}
// If HTML content getter is provided, copy both HTML and plain text
if (getHtmlContent) {
if (navigator.clipboard && getHtmlContent) {
const htmlContent = getHtmlContent();
const clipboardItem = new ClipboardItem({
"text/html": new Blob([htmlContent], { type: "text/html" }),
"text/plain": new Blob([text], { type: "text/plain" }),
});
await navigator.clipboard.write([clipboardItem]);
}
// Default: plain text only
else {
} else if (navigator.clipboard) {
await navigator.clipboard.writeText(text);
} else if (!copy(text)) {
throw new Error("copy-to-clipboard returned false");
}
// Show "copied" state
setCopyState("copied");
} catch (err) {
console.error("Failed to copy:", err);

View File

@@ -4,8 +4,9 @@ import { useState, useMemo, useRef } from "react";
import Popover from "@/refresh-components/Popover";
import { LlmManager } from "@/lib/hooks";
import { getModelIcon } from "@/lib/llmConfig";
import { Button, SelectButton, OpenButton } from "@opal/components";
import { Button, SelectButton } from "@opal/components";
import { SvgPlusCircle, SvgX } from "@opal/icons";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { LLMOption } from "@/refresh-components/popovers/interfaces";
import ModelListContent from "@/refresh-components/popovers/ModelListContent";
import Separator from "@/refresh-components/Separator";
@@ -44,8 +45,12 @@ export default function ModelSelector({
// Virtual anchor ref — points to the clicked pill so the popover positions above it
const anchorRef = useRef<HTMLElement | null>(null);
const settings = useSettingsContext();
const multiModelAllowed =
settings?.settings?.multi_model_chat_enabled ?? true;
const isMultiModel = selectedModels.length > 1;
const atMax = selectedModels.length >= MAX_MODELS;
const atMax = selectedModels.length >= MAX_MODELS || !multiModelAllowed;
const selectedKeys = useMemo(
() => new Set(selectedModels.map((m) => modelKey(m.provider, m.modelName))),
@@ -104,7 +109,10 @@ export default function ModelSelector({
onRemove(existingIndex);
} else if (!atMax) {
onAdd(model);
setOpen(false);
// Close the popover only when we've reached the max model count
if (selectedModels.length + 1 >= MAX_MODELS) {
setOpen(false);
}
}
};
@@ -158,23 +166,13 @@ export default function ModelSelector({
model.modelName
);
if (!isMultiModel) {
return (
<OpenButton
key={modelKey(model.provider, model.modelName)}
icon={ProviderIcon}
onClick={(e: React.MouseEvent) =>
handlePillClick(index, e.currentTarget as HTMLElement)
}
>
{model.displayName}
</OpenButton>
);
}
return (
<div
key={modelKey(model.provider, model.modelName)}
key={
isMultiModel
? modelKey(model.provider, model.modelName)
: "single-model-pill"
}
className="flex items-center"
>
{index > 0 && (
@@ -186,23 +184,24 @@ export default function ModelSelector({
)}
<SelectButton
icon={ProviderIcon}
rightIcon={SvgX}
rightIcon={isMultiModel ? SvgX : undefined}
state="empty"
variant="select-tinted"
interaction="hover"
variant="select-input"
size="lg"
onClick={(e: React.MouseEvent) => {
const target = e.target as HTMLElement;
const btn = e.currentTarget as HTMLElement;
const icons = btn.querySelectorAll(
".interactive-foreground-icon"
);
const lastIcon = icons[icons.length - 1];
if (lastIcon && lastIcon.contains(target)) {
onRemove(index);
} else {
handlePillClick(index, btn);
if (isMultiModel) {
const target = e.target as HTMLElement;
const btn = e.currentTarget as HTMLElement;
const icons = btn.querySelectorAll(
".interactive-foreground-icon"
);
const lastIcon = icons[icons.length - 1];
if (lastIcon && lastIcon.contains(target)) {
onRemove(index);
return;
}
}
handlePillClick(index, e.currentTarget as HTMLElement);
}}
>
{model.displayName}
@@ -216,7 +215,7 @@ export default function ModelSelector({
</div>
{!(atMax && replacingIndex === null) && (
<Popover.Content side="top" align="end" width="lg">
<Popover.Content side="top" align="end" width="xl">
<ModelListContent
llmProviders={llmManager.llmProviders}
isLoading={llmManager.isLoadingProviders}

View File

@@ -425,16 +425,27 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [multiModel.isMultiModelActive]);
// Sync single-model selection to llmManager so the submission path
// uses the correct provider/version (replaces the old LLMPopover sync).
// Sync single-model selection to llmManager so the submission path uses
// the correct provider/version. Guard against echoing derived state back
// — only call updateCurrentLlm when the selection actually differs from
// currentLlm, otherwise the initial [] → [currentLlmModel] sync would
// pin `userHasManuallyOverriddenLLM=true` with whatever was resolved
// first (often the default model before the session's alt_model loads).
useEffect(() => {
if (multiModel.selectedModels.length === 1) {
const model = multiModel.selectedModels[0]!;
llmManager.updateCurrentLlm({
name: model.name,
provider: model.provider,
modelName: model.modelName,
});
const current = llmManager.currentLlm;
if (
model.provider !== current.provider ||
model.modelName !== current.modelName ||
model.name !== current.name
) {
llmManager.updateCurrentLlm({
name: model.name,
provider: model.provider,
modelName: model.modelName,
});
}
}
}, [multiModel.selectedModels]);
@@ -511,7 +522,8 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
onSubmit({
message: lastUserMsg.message,
currentMessageFiles: currentMessageFiles,
deepResearch: deepResearchEnabledForCurrentWorkflow,
deepResearch:
deepResearchEnabledForCurrentWorkflow && !multiModel.isMultiModelActive,
messageIdToResend: lastUserMsg.messageId,
});
}, [
@@ -519,6 +531,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
onSubmit,
currentMessageFiles,
deepResearchEnabledForCurrentWorkflow,
multiModel.isMultiModelActive,
]);
const toggleDocumentSidebar = useCallback(() => {
@@ -542,7 +555,9 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
onSubmit({
message,
currentMessageFiles,
deepResearch: deepResearchEnabledForCurrentWorkflow,
deepResearch:
deepResearchEnabledForCurrentWorkflow &&
!multiModel.isMultiModelActive,
selectedModels: multiModel.isMultiModelActive
? multiModel.selectedModels
: undefined,
@@ -596,7 +611,9 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
onSubmit({
message,
currentMessageFiles,
deepResearch: deepResearchEnabledForCurrentWorkflow,
deepResearch:
deepResearchEnabledForCurrentWorkflow &&
!multiModel.isMultiModelActive,
selectedModels: multiModel.isMultiModelActive
? multiModel.selectedModels
: undefined,
@@ -871,15 +888,20 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
agent={liveAgent}
isDefaultAgent={isDefaultAgent}
/>
{liveAgent && !llmManager.isLoadingProviders && (
<ModelSelector
llmManager={llmManager}
selectedModels={multiModel.selectedModels}
onAdd={multiModel.addModel}
onRemove={multiModel.removeModel}
onReplace={multiModel.replaceModel}
/>
)}
{!isSearch &&
!(
state.phase === "idle" && state.appMode === "search"
) &&
liveAgent &&
!llmManager.isLoadingProviders && (
<ModelSelector
llmManager={llmManager}
selectedModels={multiModel.selectedModels}
onAdd={multiModel.addModel}
onRemove={multiModel.removeModel}
onReplace={multiModel.replaceModel}
/>
)}
</Section>
<Spacer rem={1.5} />
</Fade>
@@ -964,6 +986,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
deepResearchEnabledForCurrentWorkflow
}
toggleDeepResearch={toggleDeepResearch}
isMultiModelActive={multiModel.isMultiModelActive}
filterManager={filterManager}
llmManager={llmManager}
initialMessage={

View File

@@ -1,9 +1,9 @@
"use client";
import { markdown } from "@opal/utils";
import React, { useCallback, useRef, useState } from "react";
import React, { useCallback, useEffect, useRef, useState } from "react";
import { useRouter } from "next/navigation";
import { Formik, Form, useFormikContext } from "formik";
import { Formik, Form } from "formik";
import useSWR, { mutate } from "swr";
import { SWR_KEYS } from "@/lib/swr-keys";
import { errorHandlingFetcher } from "@/lib/fetcher";
@@ -14,10 +14,9 @@ import Card from "@/refresh-components/cards/Card";
import Separator from "@/refresh-components/Separator";
import SimpleCollapsible from "@/refresh-components/SimpleCollapsible";
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
import SwitchField from "@/refresh-components/form/SwitchField";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import InputTextAreaField from "@/refresh-components/form/InputTextAreaField";
import InputSelectField from "@/refresh-components/form/InputSelectField";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputTextArea from "@/refresh-components/inputs/InputTextArea";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import {
SvgAddLines,
@@ -57,7 +56,6 @@ import * as ActionsLayouts from "@/layouts/actions-layouts";
import { getActionIcon } from "@/lib/tools/mcpUtils";
import { Disabled, Hoverable } from "@opal/core";
import IconButton from "@/refresh-components/buttons/IconButton";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import useFilter from "@/hooks/useFilter";
import { MCPServer } from "@/lib/tools/interfaces";
import type { IconProps } from "@opal/types";
@@ -70,26 +68,6 @@ interface DefaultAgentConfiguration {
default_system_prompt: string;
}
interface ChatPreferencesFormValues {
// Features
search_ui_enabled: boolean;
deep_research_enabled: boolean;
auto_scroll: boolean;
// Team context
company_name: string;
company_description: string;
// Advanced
maximum_chat_retention_days: string;
anonymous_user_enabled: boolean;
disable_default_assistant: boolean;
// File limits
user_file_max_upload_size_mb: string;
file_token_count_threshold_k: string;
}
interface MCPServerCardTool {
id: number;
icon: React.FunctionComponent<IconProps>;
@@ -198,6 +176,7 @@ type FileLimitFieldName =
interface NumericLimitFieldProps {
name: FileLimitFieldName;
initialValue: string;
defaultValue: string;
saveSettings: (updates: Partial<Settings>) => Promise<void>;
maxValue?: number;
@@ -206,16 +185,15 @@ interface NumericLimitFieldProps {
function NumericLimitField({
name,
initialValue: initialValueProp,
defaultValue,
saveSettings,
maxValue,
allowZero = false,
}: NumericLimitFieldProps) {
const { values, setFieldValue } =
useFormikContext<ChatPreferencesFormValues>();
const initialValue = useRef(values[name]);
const [value, setValue] = useState(initialValueProp);
const savedValue = useRef(initialValueProp);
const restoringRef = useRef(false);
const value = values[name];
const parsed = parseInt(value, 10);
const isOverMax =
@@ -223,8 +201,8 @@ function NumericLimitField({
const handleRestore = () => {
restoringRef.current = true;
initialValue.current = defaultValue;
void setFieldValue(name, defaultValue);
savedValue.current = defaultValue;
setValue(defaultValue);
void saveSettings({ [name]: parseInt(defaultValue, 10) });
};
@@ -242,11 +220,11 @@ function NumericLimitField({
if (!isValid) {
if (allowZero) {
// Empty/invalid means "no limit" — persist 0 and clear the field.
void setFieldValue(name, "");
setValue("");
void saveSettings({ [name]: 0 });
initialValue.current = "";
savedValue.current = "";
} else {
void setFieldValue(name, initialValue.current);
setValue(savedValue.current);
}
return;
}
@@ -259,10 +237,10 @@ function NumericLimitField({
// For allowZero fields, 0 means "no limit" — clear the display
// so the "No limit" placeholder is visible, but still persist 0.
if (allowZero && parsed === 0) {
void setFieldValue(name, "");
if (initialValue.current !== "") {
setValue("");
if (savedValue.current !== "") {
void saveSettings({ [name]: 0 });
initialValue.current = "";
savedValue.current = "";
}
return;
}
@@ -271,23 +249,24 @@ function NumericLimitField({
// Update the display to the canonical form (e.g. strip leading zeros).
if (value !== normalizedDisplay) {
void setFieldValue(name, normalizedDisplay);
setValue(normalizedDisplay);
}
// Persist only when the value actually changed.
if (normalizedDisplay !== initialValue.current) {
if (normalizedDisplay !== savedValue.current) {
void saveSettings({ [name]: parsed });
initialValue.current = normalizedDisplay;
savedValue.current = normalizedDisplay;
}
};
return (
<Hoverable.Root group="numericLimit" widthVariant="full">
<InputTypeInField
name={name}
<InputTypeIn
inputMode="numeric"
showClearButton={false}
pattern="[0-9]*"
value={value}
onChange={(e) => setValue(e.target.value)}
placeholder={allowZero ? "No limit" : `Default: ${defaultValue}`}
variant={isOverMax ? "error" : undefined}
rightSection={
@@ -311,14 +290,18 @@ function NumericLimitField({
interface FileSizeLimitFieldsProps {
saveSettings: (updates: Partial<Settings>) => Promise<void>;
initialUploadSizeMb: string;
defaultUploadSizeMb: string;
initialTokenThresholdK: string;
defaultTokenThresholdK: string;
maxAllowedUploadSizeMb?: number;
}
function FileSizeLimitFields({
saveSettings,
initialUploadSizeMb,
defaultUploadSizeMb,
initialTokenThresholdK,
defaultTokenThresholdK,
maxAllowedUploadSizeMb,
}: FileSizeLimitFieldsProps) {
@@ -336,6 +319,7 @@ function FileSizeLimitFields({
>
<NumericLimitField
name="user_file_max_upload_size_mb"
initialValue={initialUploadSizeMb}
defaultValue={defaultUploadSizeMb}
saveSettings={saveSettings}
maxValue={maxAllowedUploadSizeMb}
@@ -349,6 +333,7 @@ function FileSizeLimitFields({
>
<NumericLimitField
name="file_token_count_threshold_k"
initialValue={initialTokenThresholdK}
defaultValue={defaultTokenThresholdK}
saveSettings={saveSettings}
allowZero
@@ -359,18 +344,39 @@ function FileSizeLimitFields({
);
}
/**
* Inner form component that uses useFormikContext to access values
* and create save handlers for settings fields.
*/
function ChatPreferencesForm() {
const router = useRouter();
const settings = useSettingsContext();
const { values } = useFormikContext<ChatPreferencesFormValues>();
const s = settings.settings;
// Track initial text values to avoid unnecessary saves on blur
const initialCompanyName = useRef(values.company_name);
const initialCompanyDescription = useRef(values.company_description);
// Local state for text fields (save-on-blur)
const [companyName, setCompanyName] = useState(s.company_name ?? "");
const [companyDescription, setCompanyDescription] = useState(
s.company_description ?? ""
);
const savedCompanyName = useRef(companyName);
const savedCompanyDescription = useRef(companyDescription);
// Re-sync local state when settings change externally (e.g. another admin),
// but only when there's no in-progress edit (local matches last-saved value).
useEffect(() => {
const incoming = s.company_name ?? "";
if (companyName === savedCompanyName.current && incoming !== companyName) {
setCompanyName(incoming);
savedCompanyName.current = incoming;
}
}, [s.company_name]); // eslint-disable-line react-hooks/exhaustive-deps
useEffect(() => {
const incoming = s.company_description ?? "";
if (
companyDescription === savedCompanyDescription.current &&
incoming !== companyDescription
) {
setCompanyDescription(incoming);
savedCompanyDescription.current = incoming;
}
}, [s.company_description]); // eslint-disable-line react-hooks/exhaustive-deps
// Tools availability
const { tools: availableTools } = useAvailableTools();
@@ -526,16 +532,18 @@ function ChatPreferencesForm() {
<InputLayouts.Vertical
title="Team Name"
subDescription="This is added to all chat sessions as additional context to provide a richer/customized experience."
nonInteractive
>
<InputTypeInField
name="company_name"
<InputTypeIn
placeholder="Enter team name"
value={companyName}
onChange={(e) => setCompanyName(e.target.value)}
onBlur={() => {
if (values.company_name !== initialCompanyName.current) {
if (companyName !== savedCompanyName.current) {
void saveSettings({
company_name: values.company_name || null,
company_name: companyName || null,
});
initialCompanyName.current = values.company_name;
savedCompanyName.current = companyName;
}
}}
/>
@@ -544,23 +552,21 @@ function ChatPreferencesForm() {
<InputLayouts.Vertical
title="Team Context"
subDescription="Users can also provide additional individual context in their personal settings."
nonInteractive
>
<InputTextAreaField
name="company_description"
<InputTextArea
placeholder="Describe your team and how Onyx should behave."
rows={4}
maxRows={10}
autoResize
value={companyDescription}
onChange={(e) => setCompanyDescription(e.target.value)}
onBlur={() => {
if (
values.company_description !==
initialCompanyDescription.current
) {
if (companyDescription !== savedCompanyDescription.current) {
void saveSettings({
company_description: values.company_description || null,
company_description: companyDescription || null,
});
initialCompanyDescription.current =
values.company_description;
savedCompanyDescription.current = companyDescription;
}
}}
/>
@@ -604,9 +610,10 @@ function ChatPreferencesForm() {
title="Search Mode"
description="UI mode for quick document search across your organization."
disabled={uniqueSources.length === 0}
nonInteractive
>
<SwitchField
name="search_ui_enabled"
<Switch
checked={s.search_ui_enabled ?? false}
onCheckedChange={(checked) => {
void saveSettings({ search_ui_enabled: checked });
}}
@@ -616,12 +623,26 @@ function ChatPreferencesForm() {
</div>
</Disabled>
</SimpleTooltip>
<InputLayouts.Horizontal
title="Multi-Model Generation"
tag={{ title: "beta", color: "blue" }}
description="Allow multiple models to generate responses in parallel in chat."
nonInteractive
>
<Switch
checked={s.multi_model_chat_enabled ?? true}
onCheckedChange={(checked) => {
void saveSettings({ multi_model_chat_enabled: checked });
}}
/>
</InputLayouts.Horizontal>
<InputLayouts.Horizontal
title="Deep Research"
description="Agentic research system that works across the web and connected sources. Uses significantly more tokens per query."
nonInteractive
>
<SwitchField
name="deep_research_enabled"
<Switch
checked={s.deep_research_enabled ?? true}
onCheckedChange={(checked) => {
void saveSettings({ deep_research_enabled: checked });
}}
@@ -630,9 +651,10 @@ function ChatPreferencesForm() {
<InputLayouts.Horizontal
title="Chat Auto-Scroll"
description="Automatically scroll to new content as chat generates response. Users can override this in their personal settings."
nonInteractive
>
<SwitchField
name="auto_scroll"
<Switch
checked={s.auto_scroll ?? false}
onCheckedChange={(checked) => {
void saveSettings({ auto_scroll: checked });
}}
@@ -643,7 +665,7 @@ function ChatPreferencesForm() {
<Separator noPadding />
<Disabled disabled={values.disable_default_assistant}>
<Disabled disabled={s.disable_default_assistant ?? false}>
<div>
<Section gap={1.5}>
{/* Connectors */}
@@ -873,9 +895,12 @@ function ChatPreferencesForm() {
<InputLayouts.Horizontal
title="Keep Chat History"
description="Specify how long Onyx should retain chats in your organization."
nonInteractive
>
<InputSelectField
name="maximum_chat_retention_days"
<InputSelect
value={
s.maximum_chat_retention_days?.toString() ?? "forever"
}
onValueChange={(value) => {
void saveSettings({
maximum_chat_retention_days:
@@ -895,7 +920,7 @@ function ChatPreferencesForm() {
365 days
</InputSelect.Item>
</InputSelect.Content>
</InputSelectField>
</InputSelect>
</InputLayouts.Horizontal>
</Card>
@@ -906,17 +931,29 @@ function ChatPreferencesForm() {
>
<FileSizeLimitFields
saveSettings={saveSettings}
initialUploadSizeMb={
(s.user_file_max_upload_size_mb ?? 0) <= 0
? s.default_user_file_max_upload_size_mb?.toString() ??
"100"
: s.user_file_max_upload_size_mb!.toString()
}
defaultUploadSizeMb={
settings?.settings.default_user_file_max_upload_size_mb?.toString() ??
s.default_user_file_max_upload_size_mb?.toString() ??
"100"
}
initialTokenThresholdK={
s.file_token_count_threshold_k == null
? s.default_file_token_count_threshold_k?.toString() ??
"200"
: s.file_token_count_threshold_k === 0
? ""
: s.file_token_count_threshold_k.toString()
}
defaultTokenThresholdK={
settings?.settings.default_file_token_count_threshold_k?.toString() ??
s.default_file_token_count_threshold_k?.toString() ??
"200"
}
maxAllowedUploadSizeMb={
settings?.settings.max_allowed_upload_size_mb
}
maxAllowedUploadSizeMb={s.max_allowed_upload_size_mb}
/>
</InputLayouts.Vertical>
</Card>
@@ -925,9 +962,10 @@ function ChatPreferencesForm() {
<InputLayouts.Horizontal
title="Allow Anonymous Users"
description="Allow anyone to start chats without logging in. They do not see any other chats and cannot create agents or update settings."
nonInteractive
>
<SwitchField
name="anonymous_user_enabled"
<Switch
checked={s.anonymous_user_enabled ?? false}
onCheckedChange={(checked) => {
void saveSettings({ anonymous_user_enabled: checked });
}}
@@ -937,9 +975,11 @@ function ChatPreferencesForm() {
<InputLayouts.Horizontal
title="Always Start with an Agent"
description="This removes the default chat. Users will always start in an agent, and new chats will be created in their last active agent. Set featured agents to help new users get started."
nonInteractive
>
<SwitchField
name="disable_default_assistant"
<Switch
id="disable_default_assistant"
checked={s.disable_default_assistant ?? false}
onCheckedChange={(checked) => {
void saveSettings({
disable_default_assistant: checked,
@@ -1042,50 +1082,5 @@ function ChatPreferencesForm() {
}
export default function ChatPreferencesPage() {
const settings = useSettingsContext();
const initialValues: ChatPreferencesFormValues = {
// Features
search_ui_enabled: settings.settings.search_ui_enabled ?? false,
deep_research_enabled: settings.settings.deep_research_enabled ?? true,
auto_scroll: settings.settings.auto_scroll ?? false,
// Team context
company_name: settings.settings.company_name ?? "",
company_description: settings.settings.company_description ?? "",
// Advanced
maximum_chat_retention_days:
settings.settings.maximum_chat_retention_days?.toString() ?? "forever",
anonymous_user_enabled: settings.settings.anonymous_user_enabled ?? false,
disable_default_assistant:
settings.settings.disable_default_assistant ?? false,
// File limits — for upload size: 0/null means "use default";
// for token threshold: null means "use default", 0 means "no limit".
user_file_max_upload_size_mb:
(settings.settings.user_file_max_upload_size_mb ?? 0) <= 0
? settings.settings.default_user_file_max_upload_size_mb?.toString() ??
"100"
: settings.settings.user_file_max_upload_size_mb!.toString(),
file_token_count_threshold_k:
settings.settings.file_token_count_threshold_k == null
? settings.settings.default_file_token_count_threshold_k?.toString() ??
"200"
: settings.settings.file_token_count_threshold_k === 0
? ""
: settings.settings.file_token_count_threshold_k.toString(),
};
return (
<Formik
initialValues={initialValues}
onSubmit={() => {}}
enableReinitialize
>
<Form className="h-full w-full">
<ChatPreferencesForm />
</Form>
</Formik>
);
return <ChatPreferencesForm />;
}

View File

@@ -184,20 +184,18 @@ export function FileCard({
}
>
<div className="min-w-0 max-w-[12rem]">
<Interactive.Container border heightVariant="fit">
<div className="[&_.opal-content-md-title-row]:min-w-0 [&_.opal-content-md-title]:break-all">
<AttachmentItemLayout
icon={isProcessing ? SimpleLoader : SvgFileText}
title={file.name}
description={
isProcessing
? file.status === UserFileStatus.UPLOADING
? "Uploading..."
: "Processing..."
: typeLabel
}
/>
</div>
<Interactive.Container border heightVariant="fit" widthVariant="full">
<AttachmentItemLayout
icon={isProcessing ? SimpleLoader : SvgFileText}
title={file.name}
description={
isProcessing
? file.status === UserFileStatus.UPLOADING
? "Uploading..."
: "Processing..."
: typeLabel
}
/>
<Spacer horizontal rem={0.5} />
</Interactive.Container>
</div>

View File

@@ -213,9 +213,12 @@ const ChatScrollContainer = React.memo(
}
}, [updateScrollState, getScrollState]);
// Watch for content changes (MutationObserver + ResizeObserver)
// MutationObserver (structural) + ResizeObserver (height growth).
// NOT characterData — typewriter reveals don't change scrollHeight
// and firing per-char thrashed auto-scroll.
useEffect(() => {
const container = scrollContainerRef.current;
const contentWrapper = contentWrapperRef.current;
if (!container) return;
let rafId: number | null = null;
@@ -244,17 +247,17 @@ const ChatScrollContainer = React.memo(
});
};
// MutationObserver for content changes
const mutationObserver = new MutationObserver(onContentChange);
mutationObserver.observe(container, {
childList: true,
subtree: true,
characterData: true,
});
// ResizeObserver for container size changes
const resizeObserver = new ResizeObserver(onContentChange);
resizeObserver.observe(container);
if (contentWrapper) {
resizeObserver.observe(contentWrapper);
}
return () => {
mutationObserver.disconnect();

View File

@@ -331,10 +331,13 @@ const ChatUI = React.memo(
return null;
})}
{/* Error banner when last message is user message or error type */}
{/* Error banner when last message is user message or error type.
Skip for multi-model per-panel errors — those are shown in
their own panel, not as a global banner. */}
{(((error !== null || loadError !== null) &&
messages[messages.length - 1]?.type === "user") ||
messages[messages.length - 1]?.type === "error") && (
(messages[messages.length - 1]?.type === "error" &&
!messages[messages.length - 1]?.modelDisplayName)) && (
<div className={`p-4 w-full ${MSG_MAX_W} self-center`}>
<ErrorBanner
resubmit={onResubmit}

View File

@@ -86,6 +86,7 @@ export interface AppInputBarProps {
deepResearchEnabled: boolean;
setPresentingDocument?: (document: MinimalOnyxDocument) => void;
toggleDeepResearch: () => void;
isMultiModelActive?: boolean;
disabled: boolean;
ref?: React.Ref<AppInputBarHandle>;
// Side panel tab reading
@@ -109,6 +110,7 @@ const AppInputBar = React.memo(
llmManager,
deepResearchEnabled,
toggleDeepResearch,
isMultiModelActive,
setPresentingDocument,
disabled,
ref,
@@ -554,12 +556,17 @@ const AppInputBar = React.memo(
) : (
showDeepResearch && (
<SelectButton
disabled={disabled}
disabled={disabled || isMultiModelActive}
variant="select-light"
icon={SvgHourglass}
onClick={toggleDeepResearch}
state={deepResearchEnabled ? "selected" : "empty"}
foldable={!deepResearchEnabled}
tooltip={
isMultiModelActive
? "Deep Research is disabled in multi-model mode"
: undefined
}
>
Deep Research
</SelectButton>

View File

@@ -18,6 +18,7 @@ import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
mergeFetchedModelConfigurations,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
@@ -119,7 +120,13 @@ function BedrockModalInternals({
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
formikProps.setFieldValue(
"model_configurations",
mergeFetchedModelConfigurations(
models,
formikProps.values.model_configurations
)
);
};
return (

View File

@@ -14,6 +14,7 @@ import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
mergeFetchedModelConfigurations,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
@@ -50,12 +51,18 @@ function BifrostModalInternals({
const { models, error } = await fetchBifrostModels({
api_base: formikProps.values.api_base,
api_key: formikProps.values.api_key || undefined,
provider_name: LLMProviderName.BIFROST,
provider_name: existingLlmProvider?.name,
});
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
formikProps.setFieldValue(
"model_configurations",
mergeFetchedModelConfigurations(
models,
formikProps.values.model_configurations
)
);
};
return (

View File

@@ -12,6 +12,7 @@ import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues as BaseLLMModalValues,
mergeFetchedModelConfigurations,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
@@ -61,7 +62,13 @@ function LMStudioModalInternals({
if (data.error) {
throw new Error(data.error);
}
formikProps.setFieldValue("model_configurations", data.models);
formikProps.setFieldValue(
"model_configurations",
mergeFetchedModelConfigurations(
data.models,
formikProps.values.model_configurations
)
);
};
return (

View File

@@ -13,6 +13,7 @@ import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
mergeFetchedModelConfigurations,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
@@ -52,12 +53,18 @@ function LiteLLMProxyModalInternals({
const { models, error } = await fetchLiteLLMProxyModels({
api_base: formikProps.values.api_base,
api_key: formikProps.values.api_key,
provider_name: LLMProviderName.LITELLM_PROXY,
provider_name: existingLlmProvider?.name,
});
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
formikProps.setFieldValue(
"model_configurations",
mergeFetchedModelConfigurations(
models,
formikProps.values.model_configurations
)
);
};
return (

View File

@@ -15,6 +15,7 @@ import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
mergeFetchedModelConfigurations,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
@@ -83,7 +84,13 @@ function OllamaModalInternals({
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
formikProps.setFieldValue(
"model_configurations",
mergeFetchedModelConfigurations(
models,
formikProps.values.model_configurations
)
);
};
return (

View File

@@ -14,6 +14,7 @@ import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
mergeFetchedModelConfigurations,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
@@ -55,7 +56,13 @@ function OpenAICompatibleModalInternals({
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
formikProps.setFieldValue(
"model_configurations",
mergeFetchedModelConfigurations(
models,
formikProps.values.model_configurations
)
);
};
return (

View File

@@ -13,6 +13,7 @@ import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
mergeFetchedModelConfigurations,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
@@ -49,15 +50,21 @@ function OpenRouterModalInternals({
!formikProps.values.api_base || !formikProps.values.api_key;
const handleFetchModels = async () => {
const { models, error } = await fetchOpenRouterModels({
const { models: fetched, error } = await fetchOpenRouterModels({
api_base: formikProps.values.api_base,
api_key: formikProps.values.api_key,
provider_name: LLMProviderName.OPENROUTER,
provider_name: existingLlmProvider?.name,
});
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
formikProps.setFieldValue(
"model_configurations",
mergeFetchedModelConfigurations(
fetched,
formikProps.values.model_configurations
)
);
};
return (

View File

@@ -123,6 +123,30 @@ export interface BaseLLMFormValues {
custom_config?: Record<string, string>;
}
// ─── mergeFetchedModelConfigurations ──────────────────────────────────────
/**
* Merges a freshly-fetched model list with the current form state so that
* refreshing the model list does not clobber the user's selections.
*
* - If the form has no models yet (first fetch / onboarding), the fetched
* list is returned as-is so each provider's own default `is_visible` applies.
* - Otherwise, models that already exist in the form keep their prior
* `is_visible` value, and newly-discovered models are added unselected so
* the user can opt-in explicitly.
*/
export function mergeFetchedModelConfigurations(
fetched: ModelConfiguration[],
existing: ModelConfiguration[]
): ModelConfiguration[] {
if (existing.length === 0) return fetched;
const priorByName = new Map(existing.map((m) => [m.name, m]));
return fetched.map((model) => {
const prior = priorByName.get(model.name);
return { ...model, is_visible: prior ? prior.is_visible : false };
});
}
// ─── Misc ─────────────────────────────────────────────────────────────────
export type TestApiKeyResult =