Compare commits

..

29 Commits

Author SHA1 Message Date
Nikolas Garza
2425bd4d8d feat(groups): add shared resources and token limit sections (#9538) 2026-03-24 23:44:44 +00:00
Raunak Bhagat
333b2b19cb refactor: fix sidebar layout (#9601)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-24 23:22:00 +00:00
Jamison Lahman
44895b3bd6 fix(ux): disable MCP Tools toggle if needs authenticated (#9607) 2026-03-24 22:45:23 +00:00
Raunak Bhagat
78c2ecf99f refactor(opal): restructure Onyx logo icons into composable parts (#9606)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-24 22:26:28 +00:00
Ciaran Sweet
e3e0e04edc fix: update values.yaml comment for opensearch admin password secretKeyRef (#9595) 2026-03-24 21:54:03 +00:00
Justin Tahara
a19fe03bd8 fix(ui): Text focused paste from PowerPoint (#9603) 2026-03-24 21:23:58 +00:00
Nikolas Garza
415c05b5f8 feat(groups): add create group page (#9515)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-24 20:55:18 +00:00
Nikolas Garza
352fd19f0a feat(admin): inline group renaming (#9491) 2026-03-24 20:12:17 +00:00
Raunak Bhagat
41ae039bfa refactor(opal): cleanup button types in Opal (#9598) 2026-03-24 20:06:39 +00:00
Bo-Onyx
782c734287 feat(hook): integrate query processing hook point (#9533) 2026-03-24 19:47:17 +00:00
Justin Tahara
728cdb0715 feat(helm): Adding pginto specific host (#9600) 2026-03-24 19:31:02 +00:00
Justin Tahara
baf6437117 fix(mt): Preprovision all tenants at once (#9576) 2026-03-24 19:13:10 +00:00
Raunak Bhagat
f187165077 refactor(opal): opalify FilterButton + migrate all instances away from old one (#9597) 2026-03-24 19:12:00 +00:00
Evan Lohn
727be3d663 fix: eager load chat session persona (#9577) 2026-03-24 19:03:57 +00:00
Jessica Singh
98c8f9884b fix(voice): add WebSocket upgrade headers to nginx configs (#9558)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-24 18:53:39 +00:00
Jamison Lahman
d79a068984 fix(fe): settings page layout shift on load (#9594) 2026-03-24 18:15:54 +00:00
Wenxi
ba0740d15f fix(fe): map snake_case auth type API response to camelCase (#9586)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-24 18:01:51 +00:00
Jamison Lahman
86b7bed90b chore(gha): basic test selection for external deps and connector tests (#9596)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-24 10:52:40 -07:00
Jamison Lahman
aead6ab9a5 fix(fe): properly style "Sign In" button (#9480) 2026-03-24 10:02:02 -07:00
Jamison Lahman
c9d4c186dd chore(blame): ignore ruff formatting change (#9345) 2026-03-24 10:01:23 -07:00
Jamison Lahman
70aad1ec46 fix(fe): editing an LLM provider uses the global default model (#9502) 2026-03-24 10:01:07 -07:00
Wenxi
ca3cc16ead fix(fe): stop SWR retry spam and spurious logout on auth pages (#9587) 2026-03-24 16:53:56 +00:00
Jamison Lahman
9ea1780ce5 chore(fe): memory input defaults to 1 row with max of 3 (#9563) 2026-03-24 09:35:47 -07:00
Jamison Lahman
f70e5e605e feat(ux): handle when chat session id cannot be found (#9524) 2026-03-24 16:18:52 +00:00
Jamison Lahman
84b134e226 fix(a11y): hidden buttons appear on tabbing (#9518) 2026-03-24 16:18:11 +00:00
Raunak Bhagat
b17c63a7d6 feat(admin): refresh agents page with DataTable and opal components (#9376) 2026-03-24 16:04:10 +00:00
Jamison Lahman
76c41d1b0b chore(fe): always load an empty memory card (#9560) 2026-03-24 15:47:12 +00:00
Jamison Lahman
579b86f1ce chore(fe): memories save after pressing enter (#9553)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-24 15:42:36 +00:00
Jamison Lahman
a53cf13db1 chore(fe): position memories modal at the top (#9554) 2026-03-24 15:42:32 +00:00
157 changed files with 4537 additions and 4025 deletions

View File

@@ -6,3 +6,4 @@
3134e5f840c12c8f32613ce520101a047c89dcc2 # refactor(whitespace): rm temporary react fragments (#7161)
ed3f72bc75f3e3a9ae9e4d8cd38278f9c97e78b4 # refactor(whitespace): rm react fragment #7190
7b927e79c25f4ddfd18a067f489e122acd2c89de # chore(format): format files where `ruff` and `black` agree (#9339)

View File

@@ -7,6 +7,15 @@ on:
merge_group:
pull_request:
branches: [main]
paths:
- "backend/**"
- "pyproject.toml"
- "uv.lock"
- ".github/workflows/pr-external-dependency-unit-tests.yml"
- ".github/actions/setup-python-and-install-dependencies/**"
- ".github/actions/setup-playwright/**"
- "deployment/docker_compose/docker-compose.yml"
- "deployment/docker_compose/docker-compose.dev.yml"
push:
tags:
- "v*.*.*"

View File

@@ -7,6 +7,13 @@ on:
merge_group:
pull_request:
branches: [main]
paths:
- "backend/**"
- "pyproject.toml"
- "uv.lock"
- ".github/workflows/pr-python-connector-tests.yml"
- ".github/actions/setup-python-and-install-dependencies/**"
- ".github/actions/setup-playwright/**"
push:
tags:
- "v*.*.*"

View File

@@ -1,36 +0,0 @@
"""add preferred_response_id and model_display_name to chat_message
Revision ID: a3f8b2c1d4e5
Create Date: 2026-03-22
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a3f8b2c1d4e5"
down_revision = "b728689f45b1"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column(
"preferred_response_id",
sa.Integer(),
sa.ForeignKey("chat_message.id"),
nullable=True,
),
)
op.add_column(
"chat_message",
sa.Column("model_display_name", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "model_display_name")
op.drop_column("chat_message", "preferred_response_id")

View File

@@ -25,10 +25,13 @@ from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
# Soft time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
# Maximum tenants to provision in a single task run.
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
_MAX_TENANTS_PER_RUN = 5
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
@shared_task(
@@ -85,9 +88,26 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
f"To provision: {tenants_to_provision}"
)
# just provision one tenant each time we run this ... increase if needed.
if tenants_to_provision > 0:
pre_provision_tenant()
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
if batch_size < tenants_to_provision:
task_logger.info(
f"Capping batch to {batch_size} "
f"(need {tenants_to_provision}, will catch up next cycle)"
)
provisioned = 0
for i in range(batch_size):
task_logger.info(f"Provisioning tenant {i + 1}/{batch_size}")
try:
if pre_provision_tenant():
provisioned += 1
except Exception:
task_logger.exception(
f"Failed to provision tenant {i + 1}/{batch_size}, "
"continuing with remaining tenants"
)
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
except Exception:
task_logger.exception("Error in check_available_tenants task")
@@ -101,11 +121,13 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
)
def pre_provision_tenant() -> None:
def pre_provision_tenant() -> bool:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.
This function fully sets up the tenant with all necessary configurations,
so it's ready to be assigned to a user immediately.
Returns True if a tenant was successfully provisioned, False otherwise.
"""
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
# rather than inside this function
@@ -118,10 +140,10 @@ def pre_provision_tenant() -> None:
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
if not lock_provision.acquire(blocking=False):
task_logger.debug(
"Skipping pre_provision_tenant task because it is already running"
task_logger.warning(
"Skipping pre_provision_tenant — could not acquire provision lock"
)
return
return False
tenant_id: str | None = None
try:
@@ -161,6 +183,7 @@ def pre_provision_tenant() -> None:
db_session.add(new_tenant)
db_session.commit()
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
return True
except Exception:
db_session.rollback()
task_logger.error(
@@ -184,6 +207,7 @@ def pre_provision_tenant() -> None:
asyncio.run(rollback_tenant_provisioning(tenant_id))
except Exception:
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
return False
finally:
try:
lock_provision.release()

View File

@@ -800,6 +800,33 @@ def update_user_group(
return db_user_group
def rename_user_group(
db_session: Session,
user_group_id: int,
new_name: str,
) -> UserGroup:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
if db_user_group is None:
raise ValueError(f"UserGroup with id '{user_group_id}' not found")
_check_user_group_is_modifiable(db_user_group)
db_user_group.name = new_name
db_user_group.time_last_modified_by_user = func.now()
# CC pair documents in Vespa contain the group name, so we need to
# trigger a sync to update them with the new name.
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
if not DISABLE_VECTOR_DB:
db_user_group.is_up_to_date = False
db_session.commit()
return db_user_group
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)

View File

@@ -4,6 +4,7 @@ from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.persona import update_persona_access
from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
from ee.onyx.db.user_group import fetch_user_group
@@ -11,13 +12,16 @@ from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from ee.onyx.db.user_group import rename_user_group
from ee.onyx.db.user_group import update_user_curator_relationship
from ee.onyx.db.user_group import update_user_group
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UpdateGroupAgentsRequest
from ee.onyx.server.user_group.models import UserGroup
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupRename
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
@@ -27,6 +31,9 @@ from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.db.persona import get_persona_by_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -87,6 +94,32 @@ def create_user_group(
return UserGroup.from_model(db_user_group)
@router.patch("/admin/user-group/rename")
def rename_user_group_endpoint(
rename_request: UserGroupRename,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
try:
return UserGroup.from_model(
rename_user_group(
db_session=db_session,
user_group_id=rename_request.id,
new_name=rename_request.name,
)
)
except IntegrityError:
raise OnyxError(
OnyxErrorCode.DUPLICATE_RESOURCE,
f"User group with name '{rename_request.name}' already exists.",
)
except ValueError as e:
msg = str(e)
if "not found" in msg.lower():
raise OnyxError(OnyxErrorCode.NOT_FOUND, msg)
raise OnyxError(OnyxErrorCode.CONFLICT, msg)
@router.patch("/admin/user-group/{user_group_id}")
def patch_user_group(
user_group_id: int,
@@ -161,3 +194,38 @@ def delete_user_group(
user_group = fetch_user_group(db_session, user_group_id)
if user_group:
db_delete_user_group(db_session, user_group)
@router.patch("/admin/user-group/{user_group_id}/agents")
def update_group_agents(
user_group_id: int,
request: UpdateGroupAgentsRequest,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
for agent_id in request.added_agent_ids:
persona = get_persona_by_id(
persona_id=agent_id, user=user, db_session=db_session
)
current_group_ids = [g.id for g in persona.groups]
if user_group_id not in current_group_ids:
update_persona_access(
persona_id=agent_id,
creator_user_id=user.id,
db_session=db_session,
group_ids=current_group_ids + [user_group_id],
)
for agent_id in request.removed_agent_ids:
persona = get_persona_by_id(
persona_id=agent_id, user=user, db_session=db_session
)
current_group_ids = [g.id for g in persona.groups]
update_persona_access(
persona_id=agent_id,
creator_user_id=user.id,
db_session=db_session,
group_ids=[gid for gid in current_group_ids if gid != user_group_id],
)
db_session.commit()

View File

@@ -104,6 +104,16 @@ class AddUsersToUserGroupRequest(BaseModel):
user_ids: list[UUID]
class UserGroupRename(BaseModel):
id: int
name: str
class SetCuratorRequest(BaseModel):
user_id: UUID
is_curator: bool
class UpdateGroupAgentsRequest(BaseModel):
added_agent_ids: list[int]
removed_agent_ids: list[int]

View File

@@ -8,7 +8,6 @@ from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import InMemoryChatFile
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import Packet
@@ -36,13 +35,7 @@ class CreateChatSessionID(BaseModel):
chat_session_id: UUID
AnswerStreamPart = (
Packet
| MessageResponseIDInfo
| MultiModelMessageResponseIDInfo
| StreamingError
| CreateChatSessionID
)
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
AnswerStream = Iterator[AnswerStreamPart]

View File

@@ -4,11 +4,9 @@ An overview can be found in the README.md file in this directory.
"""
import io
import queue
import re
import traceback
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextvars import Token
from uuid import UUID
@@ -30,7 +28,6 @@ from onyx.chat.compression import calculate_total_history_tokens
from onyx.chat.compression import compress_chat_history
from onyx.chat.compression import find_summary_for_branch
from onyx.chat.compression import get_compression_params
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.chat.llm_loop import EmptyLLMResponseError
from onyx.chat.llm_loop import run_llm_loop
@@ -62,8 +59,7 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import reserve_multi_model_message_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -73,29 +69,33 @@ from onyx.db.models import UserFile
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingPayload
from onyx.hooks.points.query_processing import QueryProcessingResponse
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.override_models import LLMOverride
from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import MessageResponseIDInfo
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
from onyx.tools.constants import SEARCH_TOOL_ID
@@ -433,6 +433,28 @@ def determine_search_params(
)
def _resolve_query_processing_hook_result(
hook_result: QueryProcessingResponse | HookSkipped | HookSoftFailed,
message_text: str,
) -> str:
"""Apply the Query Processing hook result to the message text.
Returns the (possibly rewritten) message text, or raises OnyxError with
QUERY_REJECTED if the hook signals rejection (query is null or empty).
HookSkipped and HookSoftFailed are pass-throughs — the original text is
returned unchanged.
"""
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
return message_text
if not (hook_result.query and hook_result.query.strip()):
raise OnyxError(
OnyxErrorCode.QUERY_REJECTED,
hook_result.rejection_message
or "The hook extension for query processing did not return a valid query. No rejection reason was provided.",
)
return hook_result.query.strip()
def handle_stream_message_objects(
new_msg_req: SendMessageRequest,
user: User,
@@ -483,16 +505,24 @@ def handle_stream_message_objects(
db_session=db_session,
)
yield CreateChatSessionID(chat_session_id=chat_session.id)
chat_session = get_chat_session_by_id(
chat_session_id=chat_session.id,
user_id=user_id,
db_session=db_session,
eager_load_persona=True,
)
else:
chat_session = get_chat_session_by_id(
chat_session_id=new_msg_req.chat_session_id,
user_id=user_id,
db_session=db_session,
eager_load_persona=True,
)
persona = chat_session.persona
message_text = new_msg_req.message
user_identity = LLMUserIdentity(
user_id=llm_user_identifier, session_id=str(chat_session.id)
)
@@ -584,6 +614,28 @@ def handle_stream_message_objects(
if parent_message.message_type == MessageType.USER:
user_message = parent_message
else:
# New message — run the Query Processing hook before saving to DB.
# Skipped on regeneration: the message already exists and was accepted previously.
# Skip the hook for empty/whitespace-only messages — no meaningful query
# to process, and SendMessageRequest.message has no min_length guard.
if message_text.strip():
hook_result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=QueryProcessingPayload(
query=message_text,
# Pass None for anonymous users or authenticated users without an email
# (e.g. some SSO flows). QueryProcessingPayload.user_email is str | None,
# so None is accepted and serialised as null in both cases.
user_email=None if user.is_anonymous else user.email,
chat_session_id=str(chat_session.id),
).model_dump(),
response_type=QueryProcessingResponse,
)
message_text = _resolve_query_processing_hook_result(
hook_result, message_text
)
user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=parent_message,
@@ -923,6 +975,17 @@ def handle_stream_message_objects(
state_container=state_container,
)
except OnyxError as e:
if e.error_code is not OnyxErrorCode.QUERY_REJECTED:
log_onyx_error(e)
yield StreamingError(
error=e.detail,
error_code=e.error_code.code,
is_retryable=e.status_code >= 500,
)
db_session.rollback()
return
except ValueError as e:
logger.exception("Failed to process chat message.")
@@ -1006,568 +1069,6 @@ def handle_stream_message_objects(
logger.exception("Error in setting processing status")
def _build_model_display_name(override: LLMOverride) -> str:
"""Build a human-readable display name from an LLM override."""
if override.display_name:
return override.display_name
if override.model_version:
return override.model_version
if override.model_provider:
return override.model_provider
return "unknown"
# Sentinel placed on the merged queue when a model thread finishes.
_MODEL_DONE = object()
class _ModelIndexEmitter(Emitter):
"""Emitter that tags packets with model_index and forwards directly to a shared queue.
Unlike the standard Emitter (which accumulates in a local bus), this puts
packets into the shared merged_queue in real-time as they're emitted. This
enables true parallel streaming — packets from multiple models interleave
on the wire instead of arriving in bursts after each model completes.
"""
def __init__(self, model_idx: int, merged_queue: queue.Queue) -> None:
super().__init__(queue.Queue()) # bus exists for compat, unused
self._model_idx = model_idx
self._merged_queue = merged_queue
def emit(self, packet: Packet) -> None:
tagged_placement = Placement(
turn_index=packet.placement.turn_index if packet.placement else 0,
tab_index=packet.placement.tab_index if packet.placement else 0,
sub_turn_index=(
packet.placement.sub_turn_index if packet.placement else None
),
model_index=self._model_idx,
)
tagged_packet = Packet(placement=tagged_placement, obj=packet.obj)
self._merged_queue.put((self._model_idx, tagged_packet))
def run_multi_model_stream(
new_msg_req: SendMessageRequest,
user: User,
db_session: Session,
llm_overrides: list[LLMOverride],
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
mcp_headers: dict[str, str] | None = None,
) -> AnswerStream:
# TODO: The setup logic below (session resolution through tool construction)
# is duplicated from handle_stream_message_objects. Extract into a shared
# _ChatStreamContext dataclass + _prepare_chat_stream_context() factory so
# both paths call the same setup code. Tracked as follow-up refactor.
"""Run 2-3 LLMs in parallel and yield their packets tagged with model_index.
Resource management:
- Each model thread gets its OWN db_session (SQLAlchemy sessions are not thread-safe)
- The caller's db_session is used only for setup (before threads launch) and
completion callbacks (after threads finish)
- ThreadPoolExecutor is bounded to len(overrides) workers
- All threads are joined in the finally block regardless of success/failure
- Queue-based merging avoids busy-waiting
"""
n_models = len(llm_overrides)
if n_models < 2 or n_models > 3:
raise ValueError(f"Multi-model requires 2-3 overrides, got {n_models}")
if new_msg_req.deep_research:
raise ValueError("Multi-model is not supported with deep research")
tenant_id = get_current_tenant_id()
cache: CacheBackend | None = None
chat_session: ChatSession | None = None
user_id = user.id
if user.is_anonymous:
llm_user_identifier = "anonymous_user"
else:
llm_user_identifier = user.email or str(user_id)
try:
# ── Session setup (same as single-model path) ──────────────────
if not new_msg_req.chat_session_id:
if not new_msg_req.chat_session_info:
raise RuntimeError(
"Must specify a chat session id or chat session info"
)
chat_session = create_chat_session_from_request(
chat_session_request=new_msg_req.chat_session_info,
user_id=user_id,
db_session=db_session,
)
yield CreateChatSessionID(chat_session_id=chat_session.id)
else:
chat_session = get_chat_session_by_id(
chat_session_id=new_msg_req.chat_session_id,
user_id=user_id,
db_session=db_session,
)
persona = chat_session.persona
message_text = new_msg_req.message
# ── Build N LLM instances and validate costs ───────────────────
llms: list[LLM] = []
model_display_names: list[str] = []
for override in llm_overrides:
llm = get_llm_for_persona(
persona=persona,
user=user,
llm_override=override,
additional_headers=litellm_additional_headers,
)
check_llm_cost_limit_for_provider(
db_session=db_session,
tenant_id=tenant_id,
llm_provider_api_key=llm.config.api_key,
)
llms.append(llm)
model_display_names.append(_build_model_display_name(override))
# Use first LLM for token counting (context window is checked per-model
# but token counting is model-agnostic enough for setup purposes)
token_counter = get_llm_token_counter(llms[0])
verify_user_files(
user_files=new_msg_req.file_descriptors,
user_id=user_id,
db_session=db_session,
project_id=chat_session.project_id,
)
# ── Chat history chain (shared across all models) ──────────────
chat_history = create_chat_history_chain(
chat_session_id=chat_session.id, db_session=db_session
)
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
)
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
parent_message = chat_history[-1] if chat_history else root_message
elif (
new_msg_req.parent_message_id is None
or new_msg_req.parent_message_id == root_message.id
):
parent_message = root_message
chat_history = []
else:
parent_message = None
for i in range(len(chat_history) - 1, -1, -1):
if chat_history[i].id == new_msg_req.parent_message_id:
parent_message = chat_history[i]
chat_history = chat_history[: i + 1]
break
if parent_message is None:
raise ValueError(
"The new message sent is not on the latest mainline of messages"
)
if parent_message.message_type == MessageType.USER:
user_message = parent_message
else:
user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=parent_message,
message=message_text,
token_count=token_counter(message_text),
message_type=MessageType.USER,
files=new_msg_req.file_descriptors,
db_session=db_session,
commit=True,
)
chat_history.append(user_message)
available_files = _collect_available_file_ids(
chat_history=chat_history,
project_id=chat_session.project_id,
user_id=user_id,
db_session=db_session,
)
summary_message = find_summary_for_branch(db_session, chat_history)
summarized_file_metadata: dict[str, FileToolMetadata] = {}
if summary_message and summary_message.last_summarized_message_id:
cutoff_id = summary_message.last_summarized_message_id
for msg in chat_history:
if msg.id > cutoff_id or not msg.files:
continue
for fd in msg.files:
file_id = fd.get("id")
if not file_id:
continue
summarized_file_metadata[file_id] = FileToolMetadata(
file_id=file_id,
filename=fd.get("name") or "unknown",
approx_char_count=0,
)
chat_history = [m for m in chat_history if m.id > cutoff_id]
user_memory_context = get_memories(user, db_session)
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
prompt_memory_context = (
user_memory_context
if user.use_memories
else user_memory_context.without_memories()
)
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
custom_agent_prompt or ""
)
reserved_token_count = calculate_reserved_tokens(
db_session=db_session,
persona_system_prompt=max_reserved_system_prompt_tokens_str,
token_counter=token_counter,
files=new_msg_req.file_descriptors,
user_memory_context=prompt_memory_context,
)
context_user_files = resolve_context_user_files(
persona=persona,
project_id=chat_session.project_id,
user_id=user_id,
db_session=db_session,
)
# Use the smallest context window across all models for safety
min_context_window = min(llm.config.max_input_tokens for llm in llms)
extracted_context_files = extract_context_files(
user_files=context_user_files,
llm_max_context_window=min_context_window,
reserved_token_count=reserved_token_count,
db_session=db_session,
)
search_params = determine_search_params(
persona_id=persona.id,
project_id=chat_session.project_id,
extracted_context_files=extracted_context_files,
)
if persona.user_files:
existing = set(available_files.user_file_ids)
for uf in persona.user_files:
if uf.id not in existing:
available_files.user_file_ids.append(uf.id)
all_tools = get_tools(db_session)
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
search_tool_id = next(
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
None,
)
forced_tool_id = new_msg_req.forced_tool_id
if (
search_params.search_usage == SearchToolUsage.DISABLED
and forced_tool_id is not None
and search_tool_id is not None
and forced_tool_id == search_tool_id
):
forced_tool_id = None
files = load_all_chat_files(chat_history, db_session)
chat_files_for_tools = _convert_loaded_files_to_chat_files(files)
# ── Reserve N assistant message IDs ────────────────────────────
reserved_messages = reserve_multi_model_message_ids(
db_session=db_session,
chat_session_id=chat_session.id,
parent_message_id=user_message.id,
model_display_names=model_display_names,
)
yield MultiModelMessageResponseIDInfo(
user_message_id=user_message.id,
reserved_assistant_message_ids=[m.id for m in reserved_messages],
model_names=model_display_names,
)
has_file_reader_tool = any(
tool.in_code_tool_id == "file_reader" for tool in all_tools
)
chat_history_result = convert_chat_history(
chat_history=chat_history,
files=files,
context_image_files=extracted_context_files.image_files,
additional_context=new_msg_req.additional_context,
token_counter=token_counter,
tool_id_to_name_map=tool_id_to_name_map,
)
simple_chat_history = chat_history_result.simple_messages
all_injected_file_metadata: dict[str, FileToolMetadata] = (
chat_history_result.all_injected_file_metadata
if has_file_reader_tool
else {}
)
if summarized_file_metadata:
for fid, meta in summarized_file_metadata.items():
all_injected_file_metadata.setdefault(fid, meta)
if summary_message is not None:
summary_simple = ChatMessageSimple(
message=summary_message.message,
token_count=summary_message.token_count,
message_type=MessageType.ASSISTANT,
)
simple_chat_history.insert(0, summary_simple)
# ── Stop signal and processing status ──────────────────────────
cache = get_cache_backend()
reset_cancel_status(chat_session.id, cache)
def check_is_connected() -> bool:
return check_stop_signal(chat_session.id, cache)
set_processing_status(
chat_session_id=chat_session.id,
cache=cache,
value=True,
)
# Release the main session's read transaction before the long stream
db_session.commit()
# ── Parallel model execution ───────────────────────────────────
# Each model thread writes tagged packets to this shared queue.
# Sentinel _MODEL_DONE signals that a thread finished.
merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = (
queue.Queue()
)
# Track per-model state containers for completion callbacks
state_containers: list[ChatStateContainer] = [
ChatStateContainer() for _ in range(n_models)
]
# Track which models completed successfully (for completion callbacks)
model_succeeded: list[bool] = [False] * n_models
user_identity = LLMUserIdentity(
user_id=llm_user_identifier,
session_id=str(chat_session.id),
)
def _run_model(model_idx: int) -> None:
"""Run a single model in a worker thread.
Uses _ModelIndexEmitter so packets flow directly to merged_queue
in real-time (not batched after completion). This enables true
parallel streaming where both models' tokens interleave on the wire.
DB access: tools may need a session during execution (e.g., search
tool). Each thread creates its own session via context manager.
"""
model_emitter = _ModelIndexEmitter(model_idx, merged_queue)
sc = state_containers[model_idx]
model_llm = llms[model_idx]
try:
# Each model thread gets its own DB session for tool execution.
# The session is scoped to the thread and closed when done.
with get_session_with_current_tenant() as thread_db_session:
# Construct tools per-thread with thread-local DB session
thread_tool_dict = construct_tools(
persona=persona,
db_session=thread_db_session,
emitter=model_emitter,
user=user,
llm=model_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=new_msg_req.internal_search_filters,
project_id_filter=search_params.project_id_filter,
persona_id_filter=search_params.persona_id_filter,
bypass_acl=False,
enable_slack_search=_should_enable_slack_search(
persona, new_msg_req.internal_search_filters
),
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session.id,
message_id=user_message.id,
additional_headers=custom_tool_additional_headers,
mcp_headers=mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=available_files.user_file_ids,
chat_file_ids=available_files.chat_file_ids,
),
allowed_tool_ids=new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=search_params.search_usage,
)
model_tools: list[Tool] = []
for tool_list in thread_tool_dict.values():
model_tools.extend(tool_list)
# Run the LLM loop — this blocks until the model finishes.
# Packets flow to merged_queue in real-time via the emitter.
run_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=simple_chat_history,
tools=model_tools,
custom_agent_prompt=custom_agent_prompt,
context_files=extracted_context_files,
persona=persona,
user_memory_context=user_memory_context,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
db_session=thread_db_session,
forced_tool_id=forced_tool_id,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
chat_files=chat_files_for_tools,
include_citations=new_msg_req.include_citations,
all_injected_file_metadata=all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
model_succeeded[model_idx] = True
except Exception as e:
merged_queue.put((model_idx, e))
finally:
merged_queue.put((model_idx, _MODEL_DONE))
# Launch model threads via ThreadPoolExecutor (bounded, context-propagating)
executor = ThreadPoolExecutor(
max_workers=n_models,
thread_name_prefix="multi-model",
)
futures = []
try:
for i in range(n_models):
futures.append(executor.submit(_run_model, i))
# ── Main thread: merge and yield packets ───────────────────
models_remaining = n_models
while models_remaining > 0:
try:
model_idx, item = merged_queue.get(timeout=0.3)
except queue.Empty:
# Check cancellation during idle periods
if not check_is_connected():
yield Packet(
placement=Placement(turn_index=0),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
return
continue
if item is _MODEL_DONE:
models_remaining -= 1
continue
if isinstance(item, Exception):
# Yield error as a tagged StreamingError packet
error_msg = str(item)
stack_trace = "".join(
traceback.format_exception(type(item), item, item.__traceback__)
)
# Redact API keys from error messages
model_llm = llms[model_idx]
if model_llm.config.api_key and len(model_llm.config.api_key) > 2:
error_msg = error_msg.replace(
model_llm.config.api_key, "[REDACTED_API_KEY]"
)
stack_trace = stack_trace.replace(
model_llm.config.api_key, "[REDACTED_API_KEY]"
)
yield StreamingError(
error=error_msg,
stack_trace=stack_trace,
error_code="MODEL_ERROR",
is_retryable=True,
details={
"model": model_llm.config.model_name,
"provider": model_llm.config.model_provider,
"model_index": model_idx,
},
)
models_remaining -= 1
continue
if isinstance(item, Packet):
# Packet is already tagged with model_index by _ModelIndexEmitter
yield item
# ── Completion: save each successful model's response ──────
# Run completion callbacks on the main thread using the main
# session. This is safe because all worker threads have exited
# by this point (merged_queue fully drained).
for i in range(n_models):
if not model_succeeded[i]:
continue
try:
llm_loop_completion_handle(
state_container=state_containers[i],
is_connected=check_is_connected,
db_session=db_session,
assistant_message=reserved_messages[i],
llm=llms[i],
reserved_tokens=reserved_token_count,
)
except Exception:
logger.exception(
f"Failed completion for model {i} "
f"({model_display_names[i]})"
)
yield Packet(
placement=Placement(turn_index=0),
obj=OverallStop(type="stop", stop_reason="complete"),
)
finally:
# Ensure all threads are cleaned up regardless of how we exit
executor.shutdown(wait=True, cancel_futures=True)
except ValueError as e:
logger.exception("Failed to process multi-model chat message.")
yield StreamingError(
error=str(e),
error_code="VALIDATION_ERROR",
is_retryable=True,
)
db_session.rollback()
return
except Exception as e:
logger.exception(f"Failed multi-model chat: {e}")
stack_trace = traceback.format_exc()
yield StreamingError(
error=str(e),
stack_trace=stack_trace,
error_code="MULTI_MODEL_ERROR",
is_retryable=True,
)
db_session.rollback()
finally:
try:
if cache is not None and chat_session is not None:
set_processing_status(
chat_session_id=chat_session.id,
cache=cache,
value=False,
)
except Exception:
logger.exception("Error clearing processing status")
def llm_loop_completion_handle(
state_container: ChatStateContainer,
is_connected: Callable[[], bool],

View File

@@ -16,6 +16,7 @@ from sqlalchemy import Row
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
@@ -28,6 +29,7 @@ from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ChatSession
from onyx.db.models import ChatSessionSharedStatus
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DBSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
@@ -53,9 +55,19 @@ def get_chat_session_by_id(
db_session: Session,
include_deleted: bool = False,
is_shared: bool = False,
eager_load_persona: bool = False,
) -> ChatSession:
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
if eager_load_persona:
stmt = stmt.options(
joinedload(ChatSession.persona).options(
selectinload(Persona.tools),
selectinload(Persona.user_files),
),
joinedload(ChatSession.project),
)
if is_shared:
stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC)
else:
@@ -602,79 +614,6 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
db_session: Session,
chat_session_id: UUID,
parent_message_id: int,
model_display_names: list[str],
) -> list[ChatMessage]:
"""Reserve N assistant message placeholders for multi-model parallel streaming.
All messages share the same parent (the user message). The parent's
latest_child_message_id points to the LAST reserved message so that the
default history-chain walker picks it up.
"""
reserved: list[ChatMessage] = []
for display_name in model_display_names:
msg = ChatMessage(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
latest_child_message_id=None,
message="Response was terminated prior to completion, try regenerating.",
token_count=15,
message_type=MessageType.ASSISTANT,
model_display_name=display_name,
)
db_session.add(msg)
reserved.append(msg)
# Flush to assign IDs without committing yet
db_session.flush()
# Point parent's latest_child to the last reserved message
parent = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == parent_message_id)
.first()
)
if parent:
parent.latest_child_message_id = reserved[-1].id
db_session.commit()
return reserved
def set_preferred_response(
db_session: Session,
user_message_id: int,
preferred_assistant_message_id: int,
) -> None:
"""Set the preferred assistant response for a multi-model user message.
Validates that the user message is a USER type and that the preferred
assistant message is a direct child of that user message.
"""
user_msg = db_session.query(ChatMessage).get(user_message_id)
if user_msg is None:
raise ValueError(f"User message {user_message_id} not found")
if user_msg.message_type != MessageType.USER:
raise ValueError(f"Message {user_message_id} is not a user message")
assistant_msg = db_session.query(ChatMessage).get(preferred_assistant_message_id)
if assistant_msg is None:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} not found"
)
if assistant_msg.parent_message_id != user_message_id:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} is not a child "
f"of user message {user_message_id}"
)
user_msg.preferred_response_id = preferred_assistant_message_id
db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -897,8 +836,6 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -2645,15 +2645,6 @@ class ChatMessage(Base):
nullable=True,
)
# For multi-model turns: the user message points to which assistant response
# was selected as the preferred one to continue the conversation with.
preferred_response_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id"), nullable=True
)
# The display name of the model that generated this assistant message
model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)
# What does this message contain
reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
message: Mapped[str] = mapped_column(Text)
@@ -2721,12 +2712,6 @@ class ChatMessage(Base):
remote_side="ChatMessage.id",
)
preferred_response: Mapped["ChatMessage | None"] = relationship(
"ChatMessage",
foreign_keys=[preferred_response_id],
remote_side="ChatMessage.id",
)
# Chat messages only need to know their immediate tool call children
# If there are nested tool calls, they are stored in the tool_call_children relationship.
tool_calls: Mapped[list["ToolCall"] | None] = relationship(

View File

@@ -44,6 +44,7 @@ class OnyxErrorCode(Enum):
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
INVALID_INPUT = ("INVALID_INPUT", 400)
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
QUERY_REJECTED = ("QUERY_REJECTED", 400)
# ------------------------------------------------------------------
# Not Found (404)

View File

@@ -5,6 +5,7 @@ Usage (Celery tasks and FastAPI handlers):
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
response_type=QueryProcessingResponse,
)
if isinstance(result, HookSkipped):
@@ -14,7 +15,7 @@ Usage (Celery tasks and FastAPI handlers):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is the response payload dict from the customer's endpoint
# result is a validated Pydantic model instance (spec.response_model)
...
is_reachable update policy
@@ -53,9 +54,11 @@ The executor uses three sessions:
import json
import time
from typing import Any
from typing import TypeVar
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -81,6 +84,9 @@ class HookSoftFailed:
"""Hook was called but failed with SOFT fail strategy — continuing."""
T = TypeVar("T", bound=BaseModel)
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
@@ -268,22 +274,21 @@ def _persist_result(
# ---------------------------------------------------------------------------
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously."""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
response_type: type[T],
) -> T | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
@@ -300,13 +305,36 @@ def execute_hook(
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(timeout=timeout) as client:
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against response_type.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: T | None = None
if outcome.is_success and outcome.response_payload is not None:
try:
validated_model = response_type.model_validate(outcome.response_payload)
except ValidationError as e:
msg = (
f"Hook response failed validation against {response_type.__name__}: {e}"
)
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
@@ -323,8 +351,41 @@ def execute_hook(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if outcome.response_payload is None:
raise ValueError(
f"response_payload is None for successful hook call (hook_id={hook_id})"
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
)
return outcome.response_payload
return validated_model
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload, response_type)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise

View File

@@ -51,13 +51,12 @@ class HookPointSpec:
output_schema: ClassVar[dict[str, Any]]
def __init_subclass__(cls, **kwargs: object) -> None:
"""Enforce that every concrete subclass declares all required class attributes.
"""Enforce that every subclass declares all required class attributes.
Called automatically by Python whenever a class inherits from HookPointSpec.
Abstract subclasses (those still carrying unimplemented abstract methods) are
skipped — they are intermediate base classes and may not yet define everything.
Only fully concrete subclasses are validated, ensuring a clear TypeError at
import time rather than a confusing AttributeError at runtime.
Raises TypeError at import time if any required attribute is missing or if
payload_model / response_model are not Pydantic BaseModel subclasses.
input_schema and output_schema are derived automatically from the models.
"""
super().__init_subclass__(**kwargs)
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]

View File

@@ -15,7 +15,7 @@ class QueryProcessingPayload(BaseModel):
description="Email of the user submitting the query, or null if unauthenticated."
)
chat_session_id: str = Field(
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
)
@@ -25,7 +25,7 @@ class QueryProcessingResponse(BaseModel):
default=None,
description=(
"The query to use in the pipeline. "
"Null, empty string, or absent = reject the query."
"Null, empty string, whitespace-only, or absent = reject the query."
),
)
rejection_message: str | None = Field(

View File

@@ -11,7 +11,6 @@ class LLMOverride(BaseModel):
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None
display_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

View File

@@ -29,7 +29,6 @@ from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.process_message import run_multi_model_stream
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
from onyx.configs.app_configs import WEB_DOMAIN
@@ -47,7 +46,6 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -83,7 +81,6 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -573,38 +570,6 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in run_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=chat_message_req.llm_overrides, # type: ignore[arg-type]
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -695,26 +660,6 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
_user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
try:
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -41,16 +41,6 @@ class MessageResponseIDInfo(BaseModel):
reserved_assistant_message_id: int
class MultiModelMessageResponseIDInfo(BaseModel):
"""Sent at the start of a multi-model streaming response.
Contains the user message ID and the reserved assistant message IDs
for each model being run in parallel."""
user_message_id: int | None
reserved_assistant_message_ids: list[int]
model_names: list[str]
class SourceTag(Tag):
source: DocumentSource
@@ -96,9 +86,6 @@ class SendMessageRequest(BaseModel):
message: str
llm_override: LLMOverride | None = None
# For multi-model mode: up to 3 LLM overrides to run in parallel.
# When provided with >1 entry, triggers multi-model streaming.
llm_overrides: list[LLMOverride] | None = None
# Test-only override for deterministic LiteLLM mock responses.
mock_llm_response: str | None = None
@@ -224,8 +211,6 @@ class ChatMessageDetail(BaseModel):
error: str | None = None
current_feedback: str | None = None # "like" | "dislike" | null
processing_duration_seconds: float | None = None
preferred_response_id: int | None = None
model_display_name: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
@@ -233,11 +218,6 @@ class ChatMessageDetail(BaseModel):
return initial_dict
class SetPreferredResponseRequest(BaseModel):
user_message_id: int
preferred_response_id: int
class ChatSessionDetailResponse(BaseModel):
chat_session_id: UUID
description: str | None

View File

@@ -8,5 +8,3 @@ class Placement(BaseModel):
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None

View File

@@ -0,0 +1,47 @@
from sqlalchemy import inspect
from sqlalchemy.orm import Session
from onyx.db.chat import create_chat_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.models import Persona
from onyx.db.models import UserProject
from tests.external_dependency_unit.conftest import create_test_user
def test_eager_load_persona_loads_relationships(db_session: Session) -> None:
"""Verify that eager_load_persona pre-loads persona, its collections, and project."""
user = create_test_user(db_session, "eager-load")
persona = Persona(name="eager-load-test", description="test")
project = UserProject(name="eager-load-project", user_id=user.id)
db_session.add_all([persona, project])
db_session.flush()
chat_session = create_chat_session(
db_session=db_session,
description="test",
user_id=None,
persona_id=persona.id,
project_id=project.id,
)
loaded = get_chat_session_by_id(
chat_session_id=chat_session.id,
user_id=None,
db_session=db_session,
eager_load_persona=True,
)
try:
tmp = inspect(loaded)
assert tmp is not None
unloaded = tmp.unloaded
assert "persona" not in unloaded
assert "project" not in unloaded
tmp = inspect(loaded.persona)
assert tmp is not None
persona_unloaded = tmp.unloaded
assert "tools" not in persona_unloaded
assert "user_files" not in persona_unloaded
finally:
db_session.rollback()

View File

@@ -143,8 +143,8 @@ def use_mock_search_pipeline(
db_session: Session | None = None, # noqa: ARG001
auto_detect_filters: bool = False, # noqa: ARG001
llm: LLM | None = None, # noqa: ARG001
project_id: int | None = None, # noqa: ARG001
persona_id: int | None = None, # noqa: ARG001
project_id_filter: int | None = None, # noqa: ARG001
persona_id_filter: int | None = None, # noqa: ARG001
# Pre-fetched data (used by SearchTool to avoid DB access in parallel calls)
acl_filters: list[str] | None = None, # noqa: ARG001
embedding_model: EmbeddingModel | None = None, # noqa: ARG001

View File

@@ -0,0 +1,53 @@
"""Tests for user group rename DB operation."""
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from ee.onyx.db.user_group import rename_user_group
from onyx.db.models import UserGroup
class TestRenameUserGroup:
"""Tests for rename_user_group function."""
@patch("ee.onyx.db.user_group.DISABLE_VECTOR_DB", False)
@patch(
"ee.onyx.db.user_group._mark_user_group__cc_pair_relationships_outdated__no_commit"
)
def test_rename_succeeds_and_triggers_sync(
self, mock_mark_outdated: MagicMock
) -> None:
mock_session = MagicMock()
mock_group = MagicMock(spec=UserGroup)
mock_group.name = "Old Name"
mock_group.is_up_to_date = True
mock_session.scalar.return_value = mock_group
result = rename_user_group(mock_session, user_group_id=1, new_name="New Name")
assert result.name == "New Name"
assert result.is_up_to_date is False
mock_mark_outdated.assert_called_once()
mock_session.commit.assert_called_once()
def test_rename_group_not_found(self) -> None:
mock_session = MagicMock()
mock_session.scalar.return_value = None
with pytest.raises(ValueError, match="not found"):
rename_user_group(mock_session, user_group_id=999, new_name="New Name")
mock_session.commit.assert_not_called()
def test_rename_group_syncing_raises(self) -> None:
mock_session = MagicMock()
mock_group = MagicMock(spec=UserGroup)
mock_group.is_up_to_date = False
mock_session.scalar.return_value = mock_group
with pytest.raises(ValueError, match="currently syncing"):
rename_user_group(mock_session, user_group_id=1, new_name="New Name")
mock_session.commit.assert_not_called()

View File

@@ -0,0 +1,216 @@
"""
Unit tests for the check_available_tenants task.
Tests verify:
- Provisioning loop calls pre_provision_tenant the correct number of times
- Batch size is capped at _MAX_TENANTS_PER_RUN
- A failure in one provisioning call does not stop subsequent calls
- No provisioning happens when pool is already full
- TARGET_AVAILABLE_TENANTS is respected
"""
from unittest.mock import MagicMock
import pytest
from ee.onyx.background.celery.tasks.tenant_provisioning.tasks import (
_MAX_TENANTS_PER_RUN,
)
from ee.onyx.background.celery.tasks.tenant_provisioning.tasks import (
check_available_tenants,
)
# Access the underlying function directly, bypassing Celery's task wrapper
# which injects `self` as the first argument when bind=True.
_check_available_tenants = check_available_tenants.run
@pytest.fixture()
def _enable_multi_tenant(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.MULTI_TENANT",
True,
)
@pytest.fixture()
def mock_redis(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
mock_lock = MagicMock()
mock_lock.acquire.return_value = True
mock_client = MagicMock()
mock_client.lock.return_value = mock_lock
monkeypatch.setattr(
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.get_redis_client",
lambda tenant_id: mock_client, # noqa: ARG005
)
return mock_client
@pytest.fixture()
def mock_pre_provision(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
mock = MagicMock(return_value=True)
monkeypatch.setattr(
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.pre_provision_tenant",
mock,
)
return mock
def _mock_available_count(monkeypatch: pytest.MonkeyPatch, count: int) -> None:
"""Set up the DB session mock to return a specific available tenant count."""
mock_session = MagicMock()
mock_session.__enter__ = MagicMock(return_value=mock_session)
mock_session.__exit__ = MagicMock(return_value=False)
mock_session.query.return_value.count.return_value = count
monkeypatch.setattr(
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.get_session_with_shared_schema",
lambda: mock_session,
)
@pytest.mark.usefixtures("_enable_multi_tenant", "mock_redis")
class TestCheckAvailableTenants:
def test_provisions_all_needed_tenants(
self,
monkeypatch: pytest.MonkeyPatch,
mock_pre_provision: MagicMock,
) -> None:
"""When pool has 2 and target is 5, should provision 3."""
monkeypatch.setattr(
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
5,
)
_mock_available_count(monkeypatch, 2)
_check_available_tenants()
assert mock_pre_provision.call_count == 3
def test_batch_capped_at_max_per_run(
self,
monkeypatch: pytest.MonkeyPatch,
mock_pre_provision: MagicMock,
) -> None:
"""When pool needs more than _MAX_TENANTS_PER_RUN, cap the batch."""
monkeypatch.setattr(
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
20,
)
_mock_available_count(monkeypatch, 0)
_check_available_tenants()
assert mock_pre_provision.call_count == _MAX_TENANTS_PER_RUN
def test_no_provisioning_when_pool_full(
self,
monkeypatch: pytest.MonkeyPatch,
mock_pre_provision: MagicMock,
) -> None:
"""When pool already meets target, should not provision anything."""
monkeypatch.setattr(
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
5,
)
_mock_available_count(monkeypatch, 5)
_check_available_tenants()
assert mock_pre_provision.call_count == 0
def test_no_provisioning_when_pool_exceeds_target(
self,
monkeypatch: pytest.MonkeyPatch,
mock_pre_provision: MagicMock,
) -> None:
"""When pool exceeds target, should not provision anything."""
monkeypatch.setattr(
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
5,
)
_mock_available_count(monkeypatch, 8)
_check_available_tenants()
assert mock_pre_provision.call_count == 0
def test_failure_does_not_stop_remaining(
self,
monkeypatch: pytest.MonkeyPatch,
mock_pre_provision: MagicMock,
) -> None:
"""If one provisioning fails, the rest should still be attempted."""
monkeypatch.setattr(
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
5,
)
_mock_available_count(monkeypatch, 0)
# Fail on calls 2 and 4 (1-indexed)
call_count = 0
def side_effect() -> bool:
nonlocal call_count
call_count += 1
if call_count in (2, 4):
raise RuntimeError("provisioning failed")
return True
mock_pre_provision.side_effect = side_effect
_check_available_tenants()
# All 5 should be attempted despite 2 failures
assert mock_pre_provision.call_count == 5
def test_skips_when_not_multi_tenant(
self,
monkeypatch: pytest.MonkeyPatch,
mock_pre_provision: MagicMock,
) -> None:
"""Should not provision when multi-tenancy is disabled."""
monkeypatch.setattr(
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.MULTI_TENANT",
False,
)
_check_available_tenants()
assert mock_pre_provision.call_count == 0
def test_skips_when_lock_not_acquired(
self,
mock_redis: MagicMock,
mock_pre_provision: MagicMock,
) -> None:
"""Should skip when another instance holds the lock."""
mock_redis.lock.return_value.acquire.return_value = False
_check_available_tenants()
assert mock_pre_provision.call_count == 0
def test_lock_release_failure_does_not_raise(
self,
monkeypatch: pytest.MonkeyPatch,
mock_redis: MagicMock,
mock_pre_provision: MagicMock,
) -> None:
"""LockNotOwnedError on release should be caught, not propagated."""
from redis.exceptions import LockNotOwnedError
monkeypatch.setattr(
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.TARGET_AVAILABLE_TENANTS",
5,
)
_mock_available_count(monkeypatch, 4)
mock_redis.lock.return_value.release.side_effect = LockNotOwnedError("expired")
# Should not raise
_check_available_tenants()
assert mock_pre_provision.call_count == 1

View File

@@ -1,206 +0,0 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in run_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
from typing import Any
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
defaults: dict[str, Any] = {
"message": "hello",
"chat_session_id": uuid4(),
}
defaults.update(kwargs)
return SendMessageRequest(**defaults)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
return LLMOverride(model_provider=provider, model_version=version)
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
"""Advance the generator one step to trigger early validation."""
from onyx.chat.process_message import run_multi_model_stream
user = MagicMock()
user.is_anonymous = False
user.email = "test@example.com"
db = MagicMock()
gen = run_multi_model_stream(req, user, db, overrides)
# Calling next() executes until the first yield OR raises.
# Validation errors are raised before any yield.
next(gen)
# ---------------------------------------------------------------------------
# run_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
def test_single_override_raises(self) -> None:
"""Exactly 1 override is not multi-model — must raise."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
def test_four_overrides_raises(self) -> None:
"""4 overrides exceeds maximum — must raise."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
def test_zero_overrides_raises(self) -> None:
"""Empty override list raises."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [])
def test_deep_research_raises(self) -> None:
"""deep_research=True is incompatible with multi-model."""
req = _make_request(deep_research=True)
with pytest.raises(ValueError, match="not supported"):
_start_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
def test_exactly_two_overrides_is_minimum(self) -> None:
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
req = _make_request()
# 1 override must fail
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
try:
_start_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
except ValueError as exc:
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
except Exception:
pass # Any other error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
def test_user_message_not_found(self) -> None:
db = MagicMock()
db.query.return_value.get.return_value = None
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=999, preferred_assistant_message_id=1
)
def test_wrong_message_type(self) -> None:
"""Cannot set preferred response on a non-USER message."""
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.ASSISTANT # wrong type
db.query.return_value.get.return_value = user_msg
with pytest.raises(ValueError, match="not a user message"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_message_not_found(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
# First call returns user_msg, second call (for assistant) returns None
db.query.return_value.get.side_effect = [user_msg, None]
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_not_child_of_user(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 999 # different parent
db.query.return_value.get.side_effect = [user_msg, assistant_msg]
with pytest.raises(ValueError, match="not a child"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_valid_call_sets_preferred_response_id(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 1 # correct parent
db.query.return_value.get.side_effect = [user_msg, assistant_msg]
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
assert user_msg.preferred_response_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
def test_display_name_defaults_none(self) -> None:
override = LLMOverride(model_provider="openai", model_version="gpt-4")
assert override.display_name is None
def test_display_name_set(self) -> None:
override = LLMOverride(
model_provider="openai",
model_version="gpt-4",
display_name="GPT-4 Turbo",
)
assert override.display_name == "GPT-4 Turbo"
def test_display_name_serializes(self) -> None:
override = LLMOverride(
model_provider="anthropic",
model_version="claude-opus-4-6",
display_name="Claude Opus",
)
d = override.model_dump()
assert d["display_name"] == "Claude Opus"

View File

@@ -1,134 +0,0 @@
"""Unit tests for multi-model answer generation types.
Tests cover:
- Placement.model_index serialization
- MultiModelMessageResponseIDInfo round-trip
- SendMessageRequest.llm_overrides backward compatibility
- ChatMessageDetail new fields
"""
from uuid import uuid4
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
class TestPlacementModelIndex:
def test_default_none(self) -> None:
p = Placement(turn_index=0)
assert p.model_index is None
def test_set_value(self) -> None:
p = Placement(turn_index=0, model_index=2)
assert p.model_index == 2
def test_serializes(self) -> None:
p = Placement(turn_index=0, tab_index=1, model_index=1)
d = p.model_dump()
assert d["model_index"] == 1
def test_none_excluded_when_default(self) -> None:
p = Placement(turn_index=0)
d = p.model_dump()
assert d["model_index"] is None
class TestMultiModelMessageResponseIDInfo:
def test_round_trip(self) -> None:
info = MultiModelMessageResponseIDInfo(
user_message_id=42,
reserved_assistant_message_ids=[43, 44, 45],
model_names=["gpt-4", "claude-opus", "gemini-pro"],
)
d = info.model_dump()
restored = MultiModelMessageResponseIDInfo(**d)
assert restored.user_message_id == 42
assert restored.reserved_assistant_message_ids == [43, 44, 45]
assert restored.model_names == ["gpt-4", "claude-opus", "gemini-pro"]
def test_null_user_message_id(self) -> None:
info = MultiModelMessageResponseIDInfo(
user_message_id=None,
reserved_assistant_message_ids=[1, 2],
model_names=["a", "b"],
)
assert info.user_message_id is None
class TestSendMessageRequestOverrides:
def test_llm_overrides_default_none(self) -> None:
req = SendMessageRequest(
message="hello",
chat_session_id=uuid4(),
)
assert req.llm_overrides is None
def test_llm_overrides_accepts_list(self) -> None:
overrides = [
LLMOverride(model_provider="openai", model_version="gpt-4"),
LLMOverride(model_provider="anthropic", model_version="claude-opus"),
]
req = SendMessageRequest(
message="hello",
chat_session_id=uuid4(),
llm_overrides=overrides,
)
assert req.llm_overrides is not None
assert len(req.llm_overrides) == 2
def test_backward_compat_single_override(self) -> None:
req = SendMessageRequest(
message="hello",
chat_session_id=uuid4(),
llm_override=LLMOverride(model_provider="openai", model_version="gpt-4"),
)
assert req.llm_override is not None
assert req.llm_overrides is None
class TestChatMessageDetailMultiModel:
def test_defaults_none(self) -> None:
from onyx.configs.constants import MessageType
detail = ChatMessageDetail(
message_id=1,
message="hello",
message_type=MessageType.ASSISTANT,
time_sent="2026-03-22T00:00:00Z",
files=[],
)
assert detail.preferred_response_id is None
assert detail.model_display_name is None
def test_set_values(self) -> None:
from onyx.configs.constants import MessageType
detail = ChatMessageDetail(
message_id=1,
message="hello",
message_type=MessageType.USER,
time_sent="2026-03-22T00:00:00Z",
files=[],
preferred_response_id=42,
model_display_name="GPT-4",
)
assert detail.preferred_response_id == 42
assert detail.model_display_name == "GPT-4"
def test_serializes(self) -> None:
from onyx.configs.constants import MessageType
detail = ChatMessageDetail(
message_id=1,
message="hello",
message_type=MessageType.ASSISTANT,
time_sent="2026-03-22T00:00:00Z",
files=[],
model_display_name="Claude Opus",
)
d = detail.model_dump()
assert d["model_display_name"] == "Claude Opus"
assert d["preferred_response_id"] is None

View File

@@ -1,4 +1,12 @@
import pytest
from onyx.chat.process_message import _resolve_query_processing_hook_result
from onyx.chat.process_message import remove_answer_citations
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
def test_remove_answer_citations_strips_http_markdown_citation() -> None:
@@ -32,3 +40,81 @@ def test_remove_answer_citations_preserves_non_citation_markdown_links() -> None
remove_answer_citations(answer)
== "See [reference](https://example.com/Function_(mathematics)) for context."
)
# ---------------------------------------------------------------------------
# Query Processing hook response handling (_resolve_query_processing_hook_result)
# ---------------------------------------------------------------------------
def test_hook_skipped_leaves_message_text_unchanged() -> None:
result = _resolve_query_processing_hook_result(HookSkipped(), "original query")
assert result == "original query"
def test_hook_soft_failed_leaves_message_text_unchanged() -> None:
result = _resolve_query_processing_hook_result(HookSoftFailed(), "original query")
assert result == "original query"
def test_null_query_raises_query_rejected() -> None:
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=None), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_empty_string_query_raises_query_rejected() -> None:
"""Empty string is falsy — must be treated as rejection, same as None."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=""), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_whitespace_only_query_raises_query_rejected() -> None:
"""Whitespace-only string is truthy but meaningless — must be treated as rejection."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=" "), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_absent_query_field_raises_query_rejected() -> None:
"""query defaults to None when not provided."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_rejection_message_surfaced_in_error_when_provided() -> None:
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(
query=None, rejection_message="Queries about X are not allowed."
),
"original query",
)
assert "Queries about X are not allowed." in str(exc_info.value)
def test_fallback_rejection_message_when_none() -> None:
"""No rejection_message → generic fallback used in OnyxError detail."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=None, rejection_message=None),
"original query",
)
assert "No rejection reason was provided." in str(exc_info.value)
def test_nonempty_query_rewrites_message_text() -> None:
result = _resolve_query_processing_hook_result(
QueryProcessingResponse(query="rewritten query"), "original query"
)
assert result == "rewritten query"

View File

@@ -7,6 +7,7 @@ from unittest.mock import patch
import httpx
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
@@ -15,13 +16,15 @@ from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
# A valid QueryProcessingResponse payload — used by success-path tests.
_RESPONSE_PAYLOAD: dict[str, Any] = {"query": "better test"}
def _make_hook(
@@ -33,6 +36,7 @@ def _make_hook(
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
hook_id: int = 1,
is_reachable: bool | None = None,
hook_point: HookPoint = HookPoint.QUERY_PROCESSING,
) -> MagicMock:
hook = MagicMock()
hook.is_active = is_active
@@ -42,6 +46,7 @@ def _make_hook(
hook.id = hook_id
hook.fail_strategy = fail_strategy
hook.is_reachable = is_reachable
hook.hook_point = hook_point
return hook
@@ -140,6 +145,7 @@ def test_early_exit_returns_skipped_with_no_db_writes(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
assert isinstance(result, HookSkipped)
@@ -152,7 +158,9 @@ def test_early_exit_returns_skipped_with_no_db_writes(
# ---------------------------------------------------------------------------
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
def test_success_returns_validated_model_and_sets_reachable(
db_session: MagicMock,
) -> None:
hook = _make_hook()
with (
@@ -171,9 +179,11 @@ def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> No
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
assert result == _RESPONSE_PAYLOAD
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
_, update_kwargs = mock_update.call_args
assert update_kwargs["is_reachable"] is True
mock_log.assert_not_called()
@@ -200,9 +210,11 @@ def test_success_skips_reachable_write_when_already_true(db_session: MagicMock)
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
assert result == _RESPONSE_PAYLOAD
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
mock_update.assert_not_called()
@@ -230,6 +242,7 @@ def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
assert isinstance(result, HookSoftFailed)
@@ -265,6 +278,7 @@ def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
assert isinstance(result, HookSoftFailed)
@@ -388,6 +402,7 @@ def test_http_failure_paths(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
@@ -395,6 +410,7 @@ def test_http_failure_paths(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
assert isinstance(result, expected_type)
@@ -442,6 +458,7 @@ def test_authorization_header(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
_, call_kwargs = mock_client.post.call_args
@@ -457,16 +474,16 @@ def test_authorization_header(
@pytest.mark.parametrize(
"http_exception,expected_result",
"http_exception,expect_onyx_error",
[
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
pytest.param(None, False, id="success_path"),
pytest.param(httpx.ConnectError("refused"), True, id="hard_fail_path"),
],
)
def test_persist_session_failure_is_swallowed(
db_session: MagicMock,
http_exception: Exception | None,
expected_result: Any,
expect_onyx_error: bool,
) -> None:
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
@@ -489,12 +506,13 @@ def test_persist_session_failure_is_swallowed(
side_effect=http_exception,
)
if expected_result is OnyxError:
if expect_onyx_error:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
@@ -502,8 +520,131 @@ def test_persist_session_failure_is_swallowed(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
assert result == expected_result
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
# ---------------------------------------------------------------------------
# Response model validation
# ---------------------------------------------------------------------------
class _StrictResponse(BaseModel):
"""Strict model used to reliably trigger a ValidationError in tests."""
required_field: str # no default → missing key raises ValidationError
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(
HookFailStrategy.SOFT, HookSoftFailed, id="validation_failure_soft"
),
pytest.param(HookFailStrategy.HARD, OnyxError, id="validation_failure_hard"),
],
)
def test_response_validation_failure_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""A response that fails response_model validation is treated like any other
hook failure: logged, is_reachable left unchanged, fail_strategy respected."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
# Response payload is missing required_field → ValidationError
_setup_client(mock_client_cls, response=_make_response(json_return={}))
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=_StrictResponse,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=_StrictResponse,
)
assert isinstance(result, HookSoftFailed)
# is_reachable must not be updated — server responded correctly
mock_update.assert_not_called()
# failure must be logged
mock_log.assert_called_once()
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "validation" in (log_kwargs["error_message"] or "").lower()
# ---------------------------------------------------------------------------
# Outer soft-fail guard in execute_hook
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(HookFailStrategy.SOFT, HookSoftFailed, id="unexpected_exc_soft"),
pytest.param(HookFailStrategy.HARD, ValueError, id="unexpected_exc_hard"),
],
)
def test_unexpected_exception_in_inner_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""An unexpected exception raised by _execute_hook_inner (not an OnyxError from
HARD fail — e.g. a bug or an assertion error) must be swallowed and return
HookSoftFailed for SOFT strategy, or re-raised for HARD strategy."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor._execute_hook_inner",
side_effect=ValueError("unexpected bug"),
),
):
if expected_type is HookSoftFailed:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
assert isinstance(result, HookSoftFailed)
else:
with pytest.raises(ValueError, match="unexpected bug"):
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
@@ -535,6 +676,7 @@ def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> Non
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
response_type=QueryProcessingResponse,
)
assert isinstance(result, HookSoftFailed)

View File

@@ -23,6 +23,12 @@ upstream web_server {
# Conditionally include MCP upstream configuration
include /etc/nginx/conf.d/mcp_upstream.conf.inc;
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
map $http_upgrade $connection_upgrade {
default upgrade;
'' close;
}
server {
listen 80 default_server;
@@ -46,8 +52,10 @@ server {
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
# need to use 1.1 to support chunked transfers
# need to use 1.1 to support chunked transfers and WebSocket
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade;
proxy_buffering off;
# timeout settings

View File

@@ -23,6 +23,12 @@ upstream web_server {
# Conditionally include MCP upstream configuration
include /etc/nginx/conf.d/mcp_upstream.conf.inc;
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
map $http_upgrade $connection_upgrade {
default upgrade;
'' close;
}
server {
listen 80 default_server;
@@ -47,8 +53,10 @@ server {
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
# need to use 1.1 to support chunked transfers
# need to use 1.1 to support chunked transfers and WebSocket
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade;
proxy_buffering off;
# we don't want nginx trying to do something clever with
@@ -92,6 +100,8 @@ server {
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade;
proxy_buffering off;
# we don't want nginx trying to do something clever with
# redirects, we set the Host: header above already.

View File

@@ -23,6 +23,12 @@ upstream web_server {
# Conditionally include MCP upstream configuration
include /etc/nginx/conf.d/mcp_upstream.conf.inc;
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
map $http_upgrade $connection_upgrade {
default upgrade;
'' close;
}
server {
listen 80 default_server;
@@ -47,8 +53,10 @@ server {
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
# need to use 1.1 to support chunked transfers
# need to use 1.1 to support chunked transfers and WebSocket
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade;
proxy_buffering off;
# timeout settings
@@ -106,6 +114,8 @@ server {
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade;
proxy_buffering off;
# timeout settings

View File

@@ -28,6 +28,12 @@ data:
}
{{- end }}
# WebSocket support: only set Connection "upgrade" for actual upgrade requests
map $http_upgrade $connection_upgrade {
default upgrade;
'' close;
}
server.conf: |
server {
listen 1024;
@@ -65,6 +71,8 @@ data:
proxy_set_header X-Forwarded-Host $host;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade;
proxy_buffering off;
proxy_redirect off;
# timeout settings

View File

@@ -10,7 +10,7 @@ data:
#!/usr/bin/env sh
set -eu
HOST="${POSTGRES_HOST:-localhost}"
HOST="${PGINTO_HOST:-${POSTGRES_HOST:-localhost}}"
PORT="${POSTGRES_PORT:-5432}"
USER="${POSTGRES_USER:-postgres}"
DB="${POSTGRES_DB:-postgres}"

View File

@@ -103,7 +103,7 @@ opensearch:
- name: OPENSEARCH_INITIAL_ADMIN_PASSWORD
valueFrom:
secretKeyRef:
name: onyx-opensearch # Must match auth.opensearch.secretName.
name: onyx-opensearch # Must match auth.opensearch.secretName or auth.opensearch.existingSecret if defined.
key: opensearch_admin_password # Must match auth.opensearch.secretKeys value.
resources:
@@ -282,7 +282,7 @@ nginx:
# The ingress-nginx subchart doesn't auto-detect our custom ConfigMap changes.
# Workaround: Helm upgrade will restart if the following annotation value changes.
podAnnotations:
onyx.app/nginx-config-version: "1"
onyx.app/nginx-config-version: "2"
# Propagate DOMAIN into nginx so server_name continues to use the same env var
extraEnvs:

View File

@@ -83,6 +83,14 @@
"scope": [],
"rule": "Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments."
},
{
"scope": ["web/**"],
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
},
{
"scope": ["web/**"],
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."

View File

@@ -1,5 +1,9 @@
import "@opal/components/tooltip.css";
import { Interactive, type InteractiveStatelessProps } from "@opal/core";
import {
Disabled,
Interactive,
type InteractiveStatelessProps,
} from "@opal/core";
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
import type { TooltipSide } from "@opal/components";
import type { IconFunctionComponent } from "@opal/types";
@@ -32,9 +36,6 @@ type ButtonProps = InteractiveStatelessProps &
*/
size?: ContainerSizeVariants;
/** HTML button type. When provided, Container renders a `<button>` element. */
type?: "submit" | "button" | "reset";
/** Tooltip text shown on hover. */
tooltip?: string;
@@ -43,6 +44,9 @@ type ButtonProps = InteractiveStatelessProps &
/** Which side the tooltip appears on. */
tooltipSide?: TooltipSide;
/** Wraps the button in a Disabled context. `false` overrides parent contexts. */
disabled?: boolean;
};
// ---------------------------------------------------------------------------
@@ -59,6 +63,7 @@ function Button({
tooltip,
tooltipSide = "top",
responsiveHideText = false,
disabled,
...interactiveProps
}: ButtonProps) {
const isLarge = size === "lg";
@@ -76,7 +81,7 @@ function Button({
) : null;
const button = (
<Interactive.Stateless {...interactiveProps}>
<Interactive.Stateless type={type} {...interactiveProps}>
<Interactive.Container
type={type}
border={interactiveProps.prominence === "secondary"}
@@ -102,9 +107,7 @@ function Button({
</Interactive.Stateless>
);
if (!tooltip) return button;
return (
const result = tooltip ? (
<TooltipPrimitive.Root>
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
<TooltipPrimitive.Portal>
@@ -117,7 +120,15 @@ function Button({
</TooltipPrimitive.Content>
</TooltipPrimitive.Portal>
</TooltipPrimitive.Root>
) : (
button
);
if (disabled != null) {
return <Disabled disabled={disabled}>{result}</Disabled>;
}
return result;
}
export { Button, type ButtonProps };

View File

@@ -0,0 +1,8 @@
.opal-button-chevron {
transition: rotate 200ms ease;
}
.interactive[data-interaction="hover"] .opal-button-chevron,
.interactive[data-interaction="active"] .opal-button-chevron {
rotate: -180deg;
}

View File

@@ -0,0 +1,22 @@
import "@opal/components/buttons/chevron.css";
import type { IconProps } from "@opal/types";
import { SvgChevronDownSmall } from "@opal/icons";
import { cn } from "@opal/utils";
/**
* Chevron icon that rotates 180° when its parent `.interactive` enters
* hover / active state. Shared by OpenButton, FilterButton, and any
* future button that needs an animated dropdown indicator.
*
* Stable component identity — never causes React to remount the SVG.
*/
function ChevronIcon({ className, ...props }: IconProps) {
return (
<SvgChevronDownSmall
className={cn(className, "opal-button-chevron")}
{...props}
/>
);
}
export { ChevronIcon };

View File

@@ -0,0 +1,107 @@
import type { Meta, StoryObj } from "@storybook/react";
import { FilterButton } from "@opal/components";
import { Disabled as DisabledProvider } from "@opal/core";
import { SvgUser, SvgActions, SvgTag } from "@opal/icons";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
const meta: Meta<typeof FilterButton> = {
title: "opal/components/FilterButton",
component: FilterButton,
tags: ["autodocs"],
decorators: [
(Story) => (
<TooltipPrimitive.Provider>
<Story />
</TooltipPrimitive.Provider>
),
],
};
export default meta;
type Story = StoryObj<typeof FilterButton>;
export const Empty: Story = {
args: {
icon: SvgUser,
children: "Everyone",
},
};
export const Active: Story = {
args: {
icon: SvgUser,
active: true,
children: "By alice@example.com",
onClear: () => console.log("clear"),
},
};
export const Open: Story = {
args: {
icon: SvgActions,
interaction: "hover",
children: "All Actions",
},
};
export const ActiveOpen: Story = {
args: {
icon: SvgActions,
active: true,
interaction: "hover",
children: "2 selected",
onClear: () => console.log("clear"),
},
};
export const Disabled: Story = {
args: {
icon: SvgTag,
children: "All Tags",
},
decorators: [
(Story) => (
<DisabledProvider disabled>
<Story />
</DisabledProvider>
),
],
};
export const DisabledActive: Story = {
args: {
icon: SvgTag,
active: true,
children: "2 tags",
onClear: () => console.log("clear"),
},
decorators: [
(Story) => (
<DisabledProvider disabled>
<Story />
</DisabledProvider>
),
],
};
export const StateComparison: Story = {
render: () => (
<div style={{ display: "flex", gap: 12, alignItems: "center" }}>
<FilterButton icon={SvgUser} onClear={() => undefined}>
Everyone
</FilterButton>
<FilterButton icon={SvgUser} active onClear={() => console.log("clear")}>
By alice@example.com
</FilterButton>
</div>
),
};
export const WithTooltip: Story = {
args: {
icon: SvgUser,
children: "Everyone",
tooltip: "Filter by creator",
tooltipSide: "bottom",
},
};

View File

@@ -0,0 +1,70 @@
# FilterButton
**Import:** `import { FilterButton, type FilterButtonProps } from "@opal/components";`
A stateful filter trigger with a built-in chevron (when empty) and a clear button (when selected). Hardcodes `variant="select-filter"` and delegates to `Interactive.Stateful`, adding automatic open-state detection from Radix `data-state`. Designed to sit inside a `Popover.Trigger` for filter dropdowns.
## Relationship to OpenButton
FilterButton shares a similar call stack to `OpenButton`:
```
Interactive.Stateful → Interactive.Container → content row (icon + label + trailing indicator)
```
FilterButton is a **narrower, filter-specific** variant:
- It hardcodes `variant="select-filter"` (OpenButton uses `"select-heavy"`)
- It exposes `active?: boolean` instead of the raw `state` prop (maps to `"selected"` / `"empty"` internally)
- When active, the chevron is hidden via `visibility` and an absolutely-positioned clear `Button` with `prominence="tertiary"` overlays it — placed as a sibling outside the `<button>` to avoid nesting buttons
- It uses the shared `ChevronIcon` from `buttons/chevron` (same as OpenButton)
- It does not support `foldable`, `size`, or `width` — it is always `"lg"`
## Architecture
```
div.relative <- bounding wrapper
Interactive.Stateful <- variant="select-filter", interaction, state
└─ Interactive.Container (button) <- height="lg", default rounding/padding
└─ div.interactive-foreground
├─ div > Icon (interactive-foreground-icon)
├─ <span> label text
└─ ChevronIcon (when empty)
OR spacer div (when selected — reserves chevron space)
div.absolute <- clear Button overlay (when selected)
└─ Button (SvgX, size="2xs", prominence="tertiary")
```
- **Open-state detection** reads `data-state="open"` injected by Radix triggers (e.g. `Popover.Trigger`), falling back to the explicit `interaction` prop.
- **Chevron rotation** uses the shared `ChevronIcon` component and `buttons/chevron.css`, which rotates 180deg when `data-interaction="hover"`.
- **Clear button** is absolutely positioned outside the `<button>` element tree to avoid invalid nested `<button>` elements. An invisible spacer inside the button reserves the same space so layout doesn't shift between states.
## Props
| Prop | Type | Default | Description |
|------|------|---------|-------------|
| `icon` | `IconFunctionComponent` | **required** | Left icon component |
| `children` | `string` | **required** | Label text between icon and trailing indicator |
| `active` | `boolean` | `false` | Whether the filter has an active selection |
| `onClear` | `() => void` | **required** | Called when the clear (X) button is clicked |
| `interaction` | `"rest" \| "hover" \| "active"` | auto | JS-controlled interaction override. Falls back to Radix `data-state="open"`. |
| `tooltip` | `string` | — | Tooltip text shown on hover |
| `tooltipSide` | `TooltipSide` | `"top"` | Which side the tooltip appears on |
## Usage
```tsx
import { FilterButton } from "@opal/components";
import { SvgUser } from "@opal/icons";
// Inside a Popover (auto-detects open state)
<Popover.Trigger asChild>
<FilterButton
icon={SvgUser}
active={hasSelection}
onClear={() => clearSelection()}
>
{hasSelection ? selectionLabel : "Everyone"}
</FilterButton>
</Popover.Trigger>
```

View File

@@ -0,0 +1,120 @@
import {
Interactive,
type InteractiveStatefulInteraction,
type InteractiveStatefulProps,
} from "@opal/core";
import type { TooltipSide } from "@opal/components";
import type { IconFunctionComponent } from "@opal/types";
import { SvgX } from "@opal/icons";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
import { ChevronIcon } from "@opal/components/buttons/chevron";
import { Button } from "@opal/components/buttons/button/components";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
interface FilterButtonProps
extends Omit<InteractiveStatefulProps, "variant" | "state"> {
/** Left icon — always visible. */
icon: IconFunctionComponent;
/** Label text between icon and trailing indicator. */
children: string;
/** Whether the filter has an active selection. @default false */
active?: boolean;
/** Called when the clear (X) button is clicked in active state. */
onClear: () => void;
/** Tooltip text shown on hover. */
tooltip?: string;
/** Which side the tooltip appears on. */
tooltipSide?: TooltipSide;
}
// ---------------------------------------------------------------------------
// FilterButton
// ---------------------------------------------------------------------------
function FilterButton({
icon: Icon,
children,
onClear,
tooltip,
tooltipSide = "top",
active = false,
interaction,
...statefulProps
}: FilterButtonProps) {
// Derive open state: explicit prop > Radix data-state (injected via Slot chain)
const dataState = (statefulProps as Record<string, unknown>)["data-state"] as
| string
| undefined;
const resolvedInteraction: InteractiveStatefulInteraction =
interaction ?? (dataState === "open" ? "hover" : "rest");
const button = (
<div className="relative">
<Interactive.Stateful
{...statefulProps}
variant="select-filter"
interaction={resolvedInteraction}
state={active ? "selected" : "empty"}
>
<Interactive.Container type="button">
<div className="interactive-foreground flex flex-row items-center gap-1">
{iconWrapper(Icon, "lg", true)}
<span className="whitespace-nowrap font-main-ui-action">
{children}
</span>
<div style={{ visibility: active ? "hidden" : "visible" }}>
{iconWrapper(ChevronIcon, "lg", true)}
</div>
</div>
</Interactive.Container>
</Interactive.Stateful>
{active && (
<div className="absolute right-2 top-1/2 -translate-y-1/2">
{/* Force hover state so the X stays visually prominent against
the inverted selected background — without this it renders
dimmed and looks disabled. */}
<Button
icon={SvgX}
size="2xs"
prominence="tertiary"
tooltip="Clear filter"
interaction="hover"
onClick={(e) => {
e.stopPropagation();
onClear();
}}
/>
</div>
)}
</div>
);
if (!tooltip) return button;
return (
<TooltipPrimitive.Root>
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
<TooltipPrimitive.Portal>
<TooltipPrimitive.Content
className="opal-tooltip"
side={tooltipSide}
sideOffset={4}
>
{tooltip}
</TooltipPrimitive.Content>
</TooltipPrimitive.Portal>
</TooltipPrimitive.Root>
);
}
export { FilterButton, type FilterButtonProps };

View File

@@ -1,8 +1,5 @@
import "@opal/components/tooltip.css";
import {
Interactive,
type InteractiveStatefulState,
type InteractiveStatefulInteraction,
type InteractiveStatefulProps,
InteractiveContainerRoundingVariant,
} from "@opal/core";
@@ -22,40 +19,26 @@ type ContentPassthroughProps = DistributiveOmit<
"paddingVariant" | "widthVariant" | "ref" | "withInteractive"
>;
type LineItemButtonOwnProps = {
type LineItemButtonOwnProps = Pick<
InteractiveStatefulProps,
| "state"
| "interaction"
| "onClick"
| "href"
| "target"
| "group"
| "ref"
| "type"
> & {
/** Interactive select variant. @default "select-light" */
selectVariant?: "select-light" | "select-heavy";
/** Value state. @default "empty" */
state?: InteractiveStatefulState;
/** JS-controllable interaction state override. @default "rest" */
interaction?: InteractiveStatefulInteraction;
/** Click handler. */
onClick?: InteractiveStatefulProps["onClick"];
/** When provided, renders an anchor instead of a div. */
href?: string;
/** Anchor target (e.g. "_blank"). */
target?: string;
/** Interactive group key. */
group?: string;
/** Forwarded ref. */
ref?: React.Ref<HTMLElement>;
/** Corner rounding preset (height is always content-driven). @default "default" */
roundingVariant?: InteractiveContainerRoundingVariant;
/** Container width. @default "full" */
width?: ExtremaSizeVariants;
/** HTML button type. @default "button" */
type?: "submit" | "button" | "reset";
/** Tooltip text shown on hover. */
tooltip?: string;
@@ -79,11 +62,11 @@ function LineItemButton({
target,
group,
ref,
type = "button",
// Sizing
roundingVariant = "default",
width = "full",
type = "button",
tooltip,
tooltipSide = "top",

View File

@@ -40,13 +40,6 @@ export const Open: Story = {
},
};
export const Disabled: Story = {
args: {
disabled: true,
children: "Disabled",
},
};
export const Foldable: Story = {
args: {
foldable: true,

View File

@@ -1,5 +1,3 @@
import "@opal/components/buttons/open-button/styles.css";
import "@opal/components/tooltip.css";
import {
Interactive,
useDisabled,
@@ -9,24 +7,11 @@ import {
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
import type { InteractiveContainerRoundingVariant } from "@opal/core";
import type { TooltipSide } from "@opal/components";
import type { IconFunctionComponent, IconProps } from "@opal/types";
import { SvgChevronDownSmall } from "@opal/icons";
import type { IconFunctionComponent } from "@opal/types";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
import { cn } from "@opal/utils";
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
// ---------------------------------------------------------------------------
// Chevron (stable identity — never causes React to remount the SVG)
// ---------------------------------------------------------------------------
function ChevronIcon({ className, ...props }: IconProps) {
return (
<SvgChevronDownSmall
className={cn(className, "opal-open-button-chevron")}
{...props}
/>
);
}
import { ChevronIcon } from "@opal/components/buttons/chevron";
// ---------------------------------------------------------------------------
// Types

View File

@@ -1,8 +0,0 @@
.opal-open-button-chevron {
transition: rotate 200ms ease;
}
.interactive[data-interaction="hover"] .opal-open-button-chevron,
.interactive[data-interaction="active"] .opal-open-button-chevron {
rotate: -180deg;
}

View File

@@ -1,5 +1,4 @@
import "@opal/components/buttons/select-button/styles.css";
import "@opal/components/tooltip.css";
import {
Interactive,
useDisabled,
@@ -50,9 +49,6 @@ type SelectButtonProps = InteractiveStatefulProps &
*/
size?: ContainerSizeVariants;
/** HTML button type. Container renders a `<button>` element. */
type?: "submit" | "button" | "reset";
/** Tooltip text shown on hover. */
tooltip?: string;

View File

@@ -1,3 +1,5 @@
import "@opal/components/tooltip.css";
/* Shared types */
export type TooltipSide = "top" | "bottom" | "left" | "right";
@@ -19,6 +21,12 @@ export {
type OpenButtonProps,
} from "@opal/components/buttons/open-button/components";
/* FilterButton */
export {
FilterButton,
type FilterButtonProps,
} from "@opal/components/buttons/filter-button/components";
/* LineItemButton */
export {
LineItemButton,

View File

@@ -32,7 +32,13 @@ function ColumnVisibilityPopover<TData extends RowData>({
// User-defined columns only (exclude internal qualifier/actions)
const dataColumns = table
.getAllLeafColumns()
.filter((col) => !col.id.startsWith("__") && col.id !== "qualifier");
.filter(
(col) =>
!col.id.startsWith("__") &&
col.id !== "qualifier" &&
typeof col.columnDef.header === "string" &&
col.columnDef.header.trim() !== ""
);
return (
<Popover open={open} onOpenChange={setOpen}>

View File

@@ -88,9 +88,12 @@ function HoverableRoot({
ref,
onMouseEnter: consumerMouseEnter,
onMouseLeave: consumerMouseLeave,
onFocusCapture: consumerFocusCapture,
onBlurCapture: consumerBlurCapture,
...props
}: HoverableRootProps) {
const [hovered, setHovered] = useState(false);
const [focused, setFocused] = useState(false);
const onMouseEnter = useCallback(
(e: React.MouseEvent<HTMLDivElement>) => {
@@ -108,16 +111,40 @@ function HoverableRoot({
[consumerMouseLeave]
);
const onFocusCapture = useCallback(
(e: React.FocusEvent<HTMLDivElement>) => {
setFocused(true);
consumerFocusCapture?.(e);
},
[consumerFocusCapture]
);
const onBlurCapture = useCallback(
(e: React.FocusEvent<HTMLDivElement>) => {
if (
!(e.relatedTarget instanceof Node) ||
!e.currentTarget.contains(e.relatedTarget)
) {
setFocused(false);
}
consumerBlurCapture?.(e);
},
[consumerBlurCapture]
);
const active = hovered || focused;
const GroupContext = getOrCreateContext(group);
return (
<GroupContext.Provider value={hovered}>
<GroupContext.Provider value={active}>
<div
{...props}
ref={ref}
className={cn(widthVariants[widthVariant])}
onMouseEnter={onMouseEnter}
onMouseLeave={onMouseLeave}
onFocusCapture={onFocusCapture}
onBlurCapture={onBlurCapture}
>
{children}
</div>

View File

@@ -16,3 +16,15 @@
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-local="true"]:hover {
opacity: 1;
}
/* Focus — item (or a focusable descendant) receives keyboard focus */
.hoverable-item[data-hoverable-variant="opacity-on-hover"]:has(:focus-visible) {
opacity: 1;
}
/* Focus ring on keyboard focus */
.hoverable-item:focus-visible {
outline: 2px solid var(--border-04);
outline-offset: 2px;
border-radius: 0.25rem;
}

View File

@@ -3,7 +3,7 @@ import type { Route } from "next";
import "@opal/core/interactive/shared.css";
import React from "react";
import { cn } from "@opal/utils";
import type { WithoutStyles } from "@opal/types";
import type { ButtonType, WithoutStyles } from "@opal/types";
import {
containerSizeVariants,
type ContainerSizeVariants,
@@ -52,7 +52,7 @@ interface InteractiveContainerProps
*
* Mutually exclusive with `href`.
*/
type?: "submit" | "button" | "reset";
type?: ButtonType;
/**
* When `true`, applies a 1px border using the theme's border color.

View File

@@ -8,7 +8,7 @@ Stateful interactive surface primitive for elements that maintain a value state
| Prop | Type | Default | Description |
|------|------|---------|-------------|
| `variant` | `"select-light" \| "select-heavy" \| "sidebar"` | `"select-heavy"` | Color variant |
| `variant` | `"select-light" \| "select-heavy" \| "select-tinted" \| "select-filter" \| "sidebar"` | `"select-heavy"` | Color variant |
| `state` | `"empty" \| "filled" \| "selected"` | `"empty"` | Current value state |
| `interaction` | `"rest" \| "hover" \| "active"` | `"rest"` | JS-controlled interaction override |
| `group` | `string` | — | Tailwind group class for `group-hover:*` |

View File

@@ -4,7 +4,7 @@ import React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cn } from "@opal/utils";
import { useDisabled } from "@opal/core/disabled/components";
import type { WithoutStyles } from "@opal/types";
import type { ButtonType, WithoutStyles } from "@opal/types";
// ---------------------------------------------------------------------------
// Types
@@ -14,6 +14,7 @@ type InteractiveStatefulVariant =
| "select-light"
| "select-heavy"
| "select-tinted"
| "select-filter"
| "sidebar";
type InteractiveStatefulState = "empty" | "filled" | "selected";
type InteractiveStatefulInteraction = "rest" | "hover" | "active";
@@ -30,6 +31,8 @@ interface InteractiveStatefulProps
*
* - `"select-light"` — transparent selected background (for inline toggles)
* - `"select-heavy"` — tinted selected background (for list rows, model pickers)
* - `"select-tinted"` — like select-heavy but with a tinted rest background
* - `"select-filter"` — like select-tinted for empty/filled; selected state uses inverted tint backgrounds and inverted text (for filter buttons)
* - `"sidebar"` — for sidebar navigation items
*
* @default "select-heavy"
@@ -63,6 +66,13 @@ interface InteractiveStatefulProps
*/
group?: string;
/**
* HTML button type. When set to `"submit"`, `"button"`, or `"reset"`, the
* element is treated as inherently interactive for cursor styling purposes
* even without an explicit `onClick` or `href`.
*/
type?: ButtonType;
/**
* URL to navigate to when clicked. Passed through Slot to the child.
*/
@@ -94,6 +104,7 @@ function InteractiveStateful({
state = "empty",
interaction = "rest",
group,
type,
href,
target,
...props
@@ -104,7 +115,7 @@ function InteractiveStateful({
// so Radix Slot-injected handlers don't bypass this guard.
const classes = cn(
"interactive",
!props.onClick && !href && "!cursor-default !select-auto",
!props.onClick && !href && !type && "!cursor-default !select-auto",
group
);

View File

@@ -308,6 +308,89 @@
--interactive-foreground-icon: var(--action-link-03);
}
/* ===========================================================================
Select-Filter — empty/filled identical to Select-Tinted;
selected uses inverted tint backgrounds and inverted text
=========================================================================== */
/* ---------------------------------------------------------------------------
Select-Filter — Empty & Filled (identical colors)
--------------------------------------------------------------------------- */
.interactive[data-interactive-variant="select-filter"]:is(
[data-interactive-state="empty"],
[data-interactive-state="filled"]
) {
@apply bg-background-tint-01;
--interactive-foreground: var(--text-02);
--interactive-foreground-icon: var(--text-02);
}
.interactive[data-interactive-variant="select-filter"]:is(
[data-interactive-state="empty"],
[data-interactive-state="filled"]
):hover:not([data-disabled]),
.interactive[data-interactive-variant="select-filter"]:is(
[data-interactive-state="empty"],
[data-interactive-state="filled"]
)[data-interaction="hover"]:not([data-disabled]) {
@apply bg-background-tint-02;
--interactive-foreground: var(--text-04);
--interactive-foreground-icon: var(--text-04);
}
.interactive[data-interactive-variant="select-filter"]:is(
[data-interactive-state="empty"],
[data-interactive-state="filled"]
):active:not([data-disabled]),
.interactive[data-interactive-variant="select-filter"]:is(
[data-interactive-state="empty"],
[data-interactive-state="filled"]
)[data-interaction="active"]:not([data-disabled]) {
@apply bg-background-neutral-00;
--interactive-foreground: var(--text-05);
--interactive-foreground-icon: var(--text-05);
}
.interactive[data-interactive-variant="select-filter"]:is(
[data-interactive-state="empty"],
[data-interactive-state="filled"]
)[data-disabled] {
@apply bg-transparent;
--interactive-foreground: var(--text-01);
--interactive-foreground-icon: var(--text-01);
}
/* ---------------------------------------------------------------------------
Select-Filter — Selected
--------------------------------------------------------------------------- */
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"] {
@apply bg-background-tint-inverted-03;
--interactive-foreground: var(--text-inverted-05);
--interactive-foreground-icon: var(--text-inverted-05);
}
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"]:hover:not(
[data-disabled]
),
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"][data-interaction="hover"]:not(
[data-disabled]
) {
@apply bg-background-tint-inverted-04;
--interactive-foreground: var(--text-inverted-05);
--interactive-foreground-icon: var(--text-inverted-05);
}
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"]:active:not(
[data-disabled]
),
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"][data-interaction="active"]:not(
[data-disabled]
) {
@apply bg-background-tint-inverted-04;
--interactive-foreground: var(--text-inverted-04);
--interactive-foreground-icon: var(--text-inverted-04);
}
.interactive[data-interactive-variant="select-filter"][data-interactive-state="selected"][data-disabled] {
@apply bg-background-neutral-04;
--interactive-foreground: var(--text-inverted-04);
--interactive-foreground-icon: var(--text-inverted-02);
}
/* ===========================================================================
Sidebar
=========================================================================== */

View File

@@ -4,7 +4,7 @@ import React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cn } from "@opal/utils";
import { useDisabled } from "@opal/core/disabled/components";
import type { WithoutStyles } from "@opal/types";
import type { ButtonType, WithoutStyles } from "@opal/types";
// ---------------------------------------------------------------------------
// Types
@@ -53,6 +53,13 @@ interface InteractiveStatelessProps
*/
group?: string;
/**
* HTML button type. When set to `"submit"`, `"button"`, or `"reset"`, the
* element is treated as inherently interactive for cursor styling purposes
* even without an explicit `onClick` or `href`.
*/
type?: ButtonType;
/**
* URL to navigate to when clicked. Passed through Slot to the child.
*/
@@ -85,6 +92,7 @@ function InteractiveStateless({
prominence = "primary",
interaction = "rest",
group,
type,
href,
target,
...props
@@ -95,7 +103,7 @@ function InteractiveStateless({
// so Radix Slot-injected handlers don't bypass this guard.
const classes = cn(
"interactive",
!props.onClick && !href && "!cursor-default !select-auto",
!props.onClick && !href && !type && "!cursor-default !select-auto",
group
);

View File

@@ -0,0 +1,20 @@
import type { IconProps } from "@opal/types";
const SvgEyeOff = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M11.78 11.78C10.6922 12.6092 9.36761 13.0685 8 13.0909C3.54545 13.0909 1 8 1 8C1.79157 6.52484 2.88945 5.23602 4.22 4.22M11.78 11.78L9.34909 9.34909M11.78 11.78L15 15M4.22 4.22L1 1M4.22 4.22L6.65091 6.65091M6.66364 3.06182C7.10167 2.95929 7.55013 2.90803 8 2.90909C12.4545 2.90909 15 8 15 8C14.6137 8.72266 14.153 9.40301 13.6255 10.03M9.34909 9.34909L6.65091 6.65091M9.34909 9.34909C8.99954 9.72422 8.49873 9.94737 7.98606 9.95641C6.922 9.97519 6.02481 9.078 6.04358 8.01394C6.05263 7.50127 6.27578 7.00046 6.65091 6.65091"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgEyeOff;

View File

@@ -68,6 +68,7 @@ export { default as SvgExpand } from "@opal/icons/expand";
export { default as SvgExternalLink } from "@opal/icons/external-link";
export { default as SvgEye } from "@opal/icons/eye";
export { default as SvgEyeClosed } from "@opal/icons/eye-closed";
export { default as SvgEyeOff } from "@opal/icons/eye-off";
export { default as SvgFiles } from "@opal/icons/files";
export { default as SvgFileBraces } from "@opal/icons/file-braces";
export { default as SvgFileChartPie } from "@opal/icons/file-chart-pie";
@@ -120,7 +121,9 @@ export { default as SvgNetworkGraph } from "@opal/icons/network-graph";
export { default as SvgNotificationBubble } from "@opal/icons/notification-bubble";
export { default as SvgOllama } from "@opal/icons/ollama";
export { default as SvgOnyxLogo } from "@opal/icons/onyx-logo";
export { default as SvgOnyxLogoTyped } from "@opal/icons/onyx-logo-typed";
export { default as SvgOnyxOctagon } from "@opal/icons/onyx-octagon";
export { default as SvgOnyxTyped } from "@opal/icons/onyx-typed";
export { default as SvgOpenai } from "@opal/icons/openai";
export { default as SvgOpenrouter } from "@opal/icons/openrouter";
export { default as SvgOrganization } from "@opal/icons/organization";

View File

@@ -0,0 +1,27 @@
import SvgOnyxLogo from "@opal/icons/onyx-logo";
import SvgOnyxTyped from "@opal/icons/onyx-typed";
import { cn } from "@opal/utils";
interface OnyxLogoTypedProps {
size?: number;
className?: string;
}
// # NOTE(@raunakab):
// This ratio is not some random, magical number; it is available on Figma.
const HEIGHT_TO_GAP_RATIO = 5 / 16;
const SvgOnyxLogoTyped = ({ size: height, className }: OnyxLogoTypedProps) => {
const gap = height != null ? height * HEIGHT_TO_GAP_RATIO : undefined;
return (
<div
className={cn(`flex flex-row items-center`, className)}
style={{ gap }}
>
<SvgOnyxLogo size={height} />
<SvgOnyxTyped size={height} />
</div>
);
};
export default SvgOnyxLogoTyped;

View File

@@ -1,19 +1,27 @@
import type { IconProps } from "@opal/types";
const SvgOnyxLogo = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 56 56"
viewBox="0 0 64 64"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
fillRule="evenodd"
clipRule="evenodd"
d="M28 0 10.869 7.77 28 15.539l17.131-7.77L28 0Zm0 40.461-17.131 7.77L28 56l17.131-7.77L28 40.461Zm20.231-29.592L56 28.001l-7.769 17.131L40.462 28l7.769-17.131ZM15.538 28 7.77 10.869 0 28l7.769 17.131L15.538 28Z"
fill="currentColor"
d="M10.4014 13.25L18.875 32L10.3852 50.75L2 32L10.4014 13.25Z"
fill="var(--theme-primary-05)"
/>
<path
d="M53.5264 13.25L62 32L53.5102 50.75L45.125 32L53.5264 13.25Z"
fill="var(--theme-primary-05)"
/>
<path
d="M32 45.125L50.75 53.5625L32 62L13.25 53.5625L32 45.125Z"
fill="var(--theme-primary-05)"
/>
<path
d="M32 2L50.75 10.4375L32 18.875L13.25 10.4375L32 2Z"
fill="var(--theme-primary-05)"
/>
</svg>
);

View File

@@ -0,0 +1,28 @@
import type { IconProps } from "@opal/types";
const SvgOnyxTyped = ({ size, ...props }: IconProps) => (
<svg
height={size}
viewBox="0 0 152 64"
fill="none"
xmlns="http://www.w3.org/2000/svg"
{...props}
>
<path
d="M19.1795 51.2136C15.6695 51.2136 12.4353 50.3862 9.47691 48.7315C6.56865 47.0768 4.2621 44.8454 2.55726 42.0374C0.85242 39.1793 0 36.0955 0 32.7861C0 30.279 0.451281 27.9223 1.35384 25.716C2.30655 23.4596 3.76068 21.3285 5.71623 19.3228L11.8085 13.08C12.4604 12.6789 13.4131 12.3529 14.6666 12.1022C15.9202 11.8014 17.2991 11.6509 18.8034 11.6509C22.3134 11.6509 25.5225 12.4783 28.4307 14.133C31.3891 15.7877 33.7208 18.0441 35.4256 20.9023C37.1304 23.7103 37.9829 26.794 37.9829 30.1536C37.9829 32.6106 37.5065 34.9673 36.5538 37.2237C35.6512 39.4802 34.147 41.6864 32.041 43.8426L26.3248 49.7845C25.3219 50.2358 24.2188 50.5868 23.0154 50.8375C21.8621 51.0882 20.5835 51.2136 19.1795 51.2136ZM20.1572 43.8426C21.8621 43.8426 23.4917 43.4164 25.0461 42.5639C26.6005 41.6614 27.8541 40.3577 28.8068 38.6528C29.8097 36.948 30.3111 34.9172 30.3111 32.5605C30.3111 30.0032 29.6843 27.6966 28.4307 25.6408C27.2273 23.5849 25.6478 21.9803 23.6923 20.8271C21.7869 19.6236 19.8313 19.0219 17.8256 19.0219C16.0706 19.0219 14.4159 19.4732 12.8615 20.3758C11.3573 21.2282 10.1288 22.5068 9.17606 24.2117C8.22335 25.9166 7.747 27.9473 7.747 30.304C7.747 32.8613 8.34871 35.1679 9.55212 37.2237C10.7555 39.2796 12.31 40.9092 14.2154 42.1127C16.1709 43.2659 18.1515 43.8426 20.1572 43.8426Z"
fill="var(--theme-primary-05)"
/>
<path
d="M42.6413 50.4614V12.4031H50.6891V17.7433L55.5028 12.7039C56.0544 12.4532 56.8065 12.2276 57.7592 12.027C58.7621 11.7763 59.8903 11.6509 61.1438 11.6509C64.0521 11.6509 66.5843 12.3028 68.7404 13.6065C70.9467 14.8601 72.6264 16.6401 73.7797 18.9467C74.9831 21.2533 75.5848 23.961 75.5848 27.0698V50.4614H67.6122V29.1006C67.6122 26.9946 67.2612 25.1895 66.5592 23.6852C65.9074 22.1308 64.9547 20.9775 63.7011 20.2253C62.4977 19.4231 61.0686 19.0219 59.4139 19.0219C56.7564 19.0219 54.6253 19.9245 53.0208 21.7296C51.4663 23.4846 50.6891 25.9416 50.6891 29.1006V50.4614H42.6413Z"
fill="var(--theme-primary-05)"
/>
<path
d="M82.3035 64V56.0273H89.9753C91.2288 56.0273 92.2066 55.7264 92.9086 55.1247C93.6607 54.523 94.2625 53.5452 94.7137 52.1913L108.027 12.4031H116.751L103.664 49.4084C103.062 51.1634 102.461 52.5173 101.859 53.47C101.307 54.4227 100.53 55.4506 99.5274 56.5538L92.4573 64H82.3035ZM90.7274 46.6255L76.9633 12.4031H85.989L99.4522 46.6255H90.7274Z"
fill="var(--theme-primary-05)"
/>
<path
d="M115.657 50.4614L129.045 31.2066L116.033 12.4031H125.435L134.085 24.8134L142.358 12.4031H151.308L138.372 31.0562L151.684 50.4614H142.358L133.332 37.3742L124.683 50.4614H115.657Z"
fill="var(--theme-primary-05)"
/>
</svg>
);
export default SvgOnyxTyped;

View File

@@ -32,6 +32,8 @@ interface ContentMdPresetConfig {
optionalFont: string;
/** Aux icon size = lineHeight 2 × p-0.5. */
auxIconSize: string;
/** Left indent for the description so it aligns with the title (past the icon). */
descriptionIndent: string;
}
interface ContentMdProps {
@@ -85,6 +87,7 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
editButtonPadding: "p-0",
optionalFont: "font-main-content-muted",
auxIconSize: "1.25rem",
descriptionIndent: "1.625rem",
},
"main-ui": {
iconSize: "1rem",
@@ -97,6 +100,7 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
editButtonPadding: "p-0",
optionalFont: "font-main-ui-muted",
auxIconSize: "1rem",
descriptionIndent: "1.375rem",
},
secondary: {
iconSize: "0.75rem",
@@ -109,6 +113,7 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
editButtonPadding: "p-0",
optionalFont: "font-secondary-action",
auxIconSize: "0.75rem",
descriptionIndent: "1.125rem",
},
};
@@ -163,22 +168,25 @@ function ContentMd({
data-interactive={withInteractive || undefined}
style={{ gap: config.gap }}
>
{Icon && (
<div
className={cn(
"opal-content-md-icon-container shrink-0",
config.iconContainerPadding
)}
style={{ minHeight: config.lineHeight }}
>
<Icon
className={cn("opal-content-md-icon", config.iconColorClass)}
style={{ width: config.iconSize, height: config.iconSize }}
/>
</div>
)}
<div
className="opal-content-md-header"
data-editing={editing || undefined}
>
{Icon && (
<div
className={cn(
"opal-content-md-icon-container shrink-0",
config.iconContainerPadding
)}
style={{ minHeight: config.lineHeight }}
>
<Icon
className={cn("opal-content-md-icon", config.iconColorClass)}
style={{ width: config.iconSize, height: config.iconSize }}
/>
</div>
)}
<div className="opal-content-md-body">
<div className="opal-content-md-title-row">
{editing ? (
<div className="opal-content-md-input-sizer">
@@ -274,13 +282,16 @@ function ContentMd({
</div>
)}
</div>
{description && (
<div className="opal-content-md-description font-secondary-body text-text-03">
{description}
</div>
)}
</div>
{description && (
<div
className="opal-content-md-description font-secondary-body text-text-03"
style={Icon ? { paddingLeft: config.descriptionIndent } : undefined}
>
{description}
</div>
)}
</div>
);
}

View File

@@ -224,7 +224,16 @@
--------------------------------------------------------------------------- */
.opal-content-md {
@apply flex flex-row items-start;
@apply flex flex-col items-start;
}
.opal-content-md-header {
@apply flex flex-row items-center w-full;
}
.opal-content-md-header[data-editing] {
@apply rounded-08;
box-shadow: inset 0 0 0 1px var(--border-02);
}
/* ---------------------------------------------------------------------------
@@ -237,15 +246,6 @@
justify-content: center;
}
/* ---------------------------------------------------------------------------
Body column
--------------------------------------------------------------------------- */
.opal-content-md-body {
@apply flex flex-1 flex-col items-start;
min-width: 0.0625rem;
}
/* ---------------------------------------------------------------------------
Title row — title (or input) + edit button
--------------------------------------------------------------------------- */
@@ -267,6 +267,7 @@
.opal-content-md-input-sizer {
display: inline-grid;
align-items: stretch;
width: 100%;
}
.opal-content-md-input-sizer > * {

View File

@@ -86,6 +86,15 @@ export interface IconProps extends SVGProps<SVGSVGElement> {
/** Strips `className` and `style` from a props type to enforce design-system styling. */
export type WithoutStyles<T> = Omit<T, "className" | "style">;
/**
* HTML button `type` attribute values.
*
* Used by interactive primitives and button-like components to indicate that
* the element is inherently interactive for cursor-styling purposes, even
* without an explicit `onClick` or `href`.
*/
export type ButtonType = "submit" | "button" | "reset";
/** Like `Omit` but distributes over union types, preserving discriminated unions. */
export type DistributiveOmit<T, K extends keyof any> = T extends any
? Omit<T, K>

View File

@@ -1,320 +0,0 @@
"use client";
import Text from "@/refresh-components/texts/Text";
import { Persona } from "./interfaces";
import { useRouter } from "next/navigation";
import Checkbox from "@/refresh-components/inputs/Checkbox";
import { toast } from "@/hooks/useToast";
import { useState, useMemo, useEffect } from "react";
import { UniqueIdentifier } from "@dnd-kit/core";
import { DraggableTable } from "@/components/table/DraggableTable";
import {
deletePersona,
personaComparator,
togglePersonaFeatured,
togglePersonaVisibility,
} from "./lib";
import { FiEdit2 } from "react-icons/fi";
import { useUser } from "@/providers/UserProvider";
import { Button } from "@opal/components";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import { SvgAlertCircle, SvgTrash } from "@opal/icons";
import type { Route } from "next";
function PersonaTypeDisplay({ persona }: { persona: Persona }) {
if (persona.builtin_persona) {
return <Text as="p">Built-In</Text>;
}
if (persona.is_featured) {
return <Text as="p">Featured</Text>;
}
if (persona.is_public) {
return <Text as="p">Public</Text>;
}
if (persona.groups.length > 0 || persona.users.length > 0) {
return <Text as="p">Shared</Text>;
}
return (
<Text as="p">Personal {persona.owner && <>({persona.owner.email})</>}</Text>
);
}
export function PersonasTable({
personas,
refreshPersonas,
currentPage,
pageSize,
}: {
personas: Persona[];
refreshPersonas: () => void;
currentPage: number;
pageSize: number;
}) {
const router = useRouter();
const { refreshUser, isAdmin } = useUser();
const editablePersonas = useMemo(() => {
return personas.filter((p) => !p.builtin_persona);
}, [personas]);
const editablePersonaIds = useMemo(() => {
return new Set(editablePersonas.map((p) => p.id.toString()));
}, [editablePersonas]);
const [finalPersonas, setFinalPersonas] = useState<Persona[]>([]);
const [deleteModalOpen, setDeleteModalOpen] = useState(false);
const [personaToDelete, setPersonaToDelete] = useState<Persona | null>(null);
const [defaultModalOpen, setDefaultModalOpen] = useState(false);
const [personaToToggleDefault, setPersonaToToggleDefault] =
useState<Persona | null>(null);
useEffect(() => {
const editable = editablePersonas.sort(personaComparator);
const nonEditable = personas
.filter((p) => !editablePersonaIds.has(p.id.toString()))
.sort(personaComparator);
setFinalPersonas([...editable, ...nonEditable]);
}, [editablePersonas, personas, editablePersonaIds]);
const updatePersonaOrder = async (orderedPersonaIds: UniqueIdentifier[]) => {
const reorderedPersonas = orderedPersonaIds.map(
(id) => personas.find((persona) => persona.id.toString() === id)!
);
setFinalPersonas(reorderedPersonas);
// Calculate display_priority based on current page.
// Page 1 (items 0-9): priorities 0-9
// Page 2 (items 10-19): priorities 10-19, etc.
const pageStartIndex = (currentPage - 1) * pageSize;
const displayPriorityMap = new Map<UniqueIdentifier, number>();
orderedPersonaIds.forEach((personaId, ind) => {
displayPriorityMap.set(personaId, pageStartIndex + ind);
});
const response = await fetch("/api/admin/agents/display-priorities", {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
display_priority_map: Object.fromEntries(displayPriorityMap),
}),
});
if (!response.ok) {
toast.error(`Failed to update persona order - ${await response.text()}`);
setFinalPersonas(personas);
await refreshPersonas();
return;
}
await refreshPersonas();
await refreshUser();
};
const openDeleteModal = (persona: Persona) => {
setPersonaToDelete(persona);
setDeleteModalOpen(true);
};
const closeDeleteModal = () => {
setDeleteModalOpen(false);
setPersonaToDelete(null);
};
const handleDeletePersona = async () => {
if (personaToDelete) {
const response = await deletePersona(personaToDelete.id);
if (response.ok) {
refreshPersonas();
closeDeleteModal();
} else {
toast.error(`Failed to delete persona - ${await response.text()}`);
}
}
};
const openDefaultModal = (persona: Persona) => {
setPersonaToToggleDefault(persona);
setDefaultModalOpen(true);
};
const closeDefaultModal = () => {
setDefaultModalOpen(false);
setPersonaToToggleDefault(null);
};
const handleToggleDefault = async () => {
if (personaToToggleDefault) {
const response = await togglePersonaFeatured(
personaToToggleDefault.id,
personaToToggleDefault.is_featured
);
if (response.ok) {
refreshPersonas();
closeDefaultModal();
} else {
toast.error(`Failed to update persona - ${await response.text()}`);
}
}
};
return (
<div>
{deleteModalOpen && personaToDelete && (
<ConfirmationModalLayout
icon={SvgAlertCircle}
title="Delete Agent"
onClose={closeDeleteModal}
submit={<Button onClick={handleDeletePersona}>Delete</Button>}
>
{`Are you sure you want to delete ${personaToDelete.name}?`}
</ConfirmationModalLayout>
)}
{defaultModalOpen &&
personaToToggleDefault &&
(() => {
const isDefault = personaToToggleDefault.is_featured;
const title = isDefault
? "Remove Featured Agent"
: "Set Featured Agent";
const buttonText = isDefault ? "Remove Feature" : "Set as Featured";
const text = isDefault
? `Are you sure you want to remove the featured status of ${personaToToggleDefault.name}?`
: `Are you sure you want to set the featured status of ${personaToToggleDefault.name}?`;
const additionalText = isDefault
? `Removing "${personaToToggleDefault.name}" as a featured agent will not affect its visibility or accessibility.`
: `Setting "${personaToToggleDefault.name}" as a featured agent will make it public and visible to all users. This action cannot be undone.`;
return (
<ConfirmationModalLayout
icon={SvgAlertCircle}
title={title}
onClose={closeDefaultModal}
submit={
<Button onClick={handleToggleDefault}>{buttonText}</Button>
}
>
<div className="flex flex-col gap-2">
<Text as="p">{text}</Text>
<Text as="p" text03>
{additionalText}
</Text>
</div>
</ConfirmationModalLayout>
);
})()}
<DraggableTable
headers={[
"Name",
"Description",
"Type",
"Featured Agent",
"Is Visible",
"Delete",
]}
isAdmin={isAdmin}
rows={finalPersonas.map((persona) => {
const isEditable = editablePersonas.includes(persona);
return {
id: persona.id.toString(),
cells: [
<div key="name" className="flex">
{!persona.builtin_persona && (
<FiEdit2
className="mr-1 my-auto cursor-pointer"
onClick={() =>
router.push(
`/app/agents/edit/${
persona.id
}?u=${Date.now()}&admin=true` as Route
)
}
/>
)}
<p className="text font-medium whitespace-normal break-none">
{persona.name}
</p>
</div>,
<p
key="description"
className="whitespace-normal break-all max-w-2xl"
>
{persona.description}
</p>,
<PersonaTypeDisplay key={persona.id} persona={persona} />,
<div
key="featured"
onClick={() => {
openDefaultModal(persona);
}}
className={`
px-1 py-0.5 rounded flex hover:bg-accent-background-hovered cursor-pointer select-none w-fit items-center gap-2
`}
>
<div className="my-auto flex-none w-22">
{!persona.is_featured ? (
<div className="text-error">Not Featured</div>
) : (
"Featured"
)}
</div>
<Checkbox checked={persona.is_featured} />
</div>,
<div
key="is_visible"
onClick={async () => {
const response = await togglePersonaVisibility(
persona.id,
persona.is_listed
);
if (response.ok) {
refreshPersonas();
} else {
toast.error(
`Failed to update persona - ${await response.text()}`
);
}
}}
className={`
px-1 py-0.5 rounded flex hover:bg-accent-background-hovered cursor-pointer select-none w-fit items-center gap-2
`}
>
<div className="my-auto w-fit">
{!persona.is_listed ? (
<div className="text-error">Hidden</div>
) : (
"Visible"
)}
</div>
<Checkbox checked={persona.is_listed} />
</div>,
<div key="edit" className="flex">
<div className="mr-auto my-auto">
{!persona.builtin_persona && isEditable ? (
<Button
icon={SvgTrash}
prominence="tertiary"
onClick={() => openDeleteModal(persona)}
/>
) : (
<Text as="p">-</Text>
)}
</div>
</div>,
],
staticModifiers: [[1, "lg:w-[250px] xl:w-[400px] 2xl:w-[550px]"]],
};
})}
setRows={updatePersonaOrder}
/>
</div>
);
}

View File

@@ -1,160 +1 @@
"use client";
import { PersonasTable } from "./PersonaTable";
import Text from "@/components/ui/text";
import Title from "@/components/ui/title";
import Separator from "@/refresh-components/Separator";
import { SubLabel } from "@/components/Field";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { useAdminPersonas } from "@/hooks/useAdminPersonas";
import { Persona } from "./interfaces";
import { ThreeDotsLoader } from "@/components/Loading";
import { ErrorCallout } from "@/components/ErrorCallout";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import { useState, useEffect } from "react";
import { Pagination } from "@opal/components";
const route = ADMIN_ROUTES.AGENTS;
const PAGE_SIZE = 20;
function MainContent({
personas,
totalItems,
currentPage,
onPageChange,
refreshPersonas,
}: {
personas: Persona[];
totalItems: number;
currentPage: number;
onPageChange: (page: number) => void;
refreshPersonas: () => void;
}) {
// Filter out default/unified assistants.
// NOTE: The backend should already exclude them if includeDefault = false is
// provided. That change was made with the introduction of pagination; we keep
// this filter here for now for backwards compatibility.
const customPersonas = personas.filter((persona) => !persona.builtin_persona);
const totalPages = Math.ceil(totalItems / PAGE_SIZE);
// Clamp currentPage when totalItems shrinks (e.g., deleting the last item on a page)
useEffect(() => {
if (currentPage > totalPages && totalPages > 0) {
onPageChange(totalPages);
}
}, [currentPage, totalPages, onPageChange]);
return (
<div>
<Text className="mb-2">
Agents are a way to build custom search/question-answering experiences
for different use cases.
</Text>
<Text className="mt-2">They allow you to customize:</Text>
<div className="text-sm">
<ul className="list-disc mt-2 ml-4">
<li>
The prompt used by your LLM of choice to respond to the user query
</li>
<li>The documents that are used as context</li>
</ul>
</div>
<div>
<Separator />
<Title>Create an Agent</Title>
<CreateButton href="/app/agents/create?admin=true">
New Agent
</CreateButton>
<Separator />
<Title>Existing Agents</Title>
{totalItems > 0 ? (
<>
<SubLabel>
Agents will be displayed as options on the Chat / Search
interfaces in the order they are displayed below. Agents marked as
hidden will not be displayed. Editable agents are shown at the
top.
</SubLabel>
<PersonasTable
personas={customPersonas}
refreshPersonas={refreshPersonas}
currentPage={currentPage}
pageSize={PAGE_SIZE}
/>
{totalPages > 1 && (
<Pagination
currentPage={currentPage}
totalPages={totalPages}
onChange={onPageChange}
/>
)}
</>
) : (
<div className="mt-6 p-8 border border-border rounded-lg bg-background-weak text-center">
<Text className="text-lg font-medium mb-2">
No custom agents yet
</Text>
<Text className="text-subtle mb-3">
Create your first agent to:
</Text>
<ul className="text-subtle text-sm list-disc text-left inline-block mb-3">
<li>Build department-specific knowledge bases</li>
<li>Create specialized research agents</li>
<li>Set up compliance and policy advisors</li>
</ul>
<Text className="text-subtle text-sm mb-4">
...and so much more!
</Text>
<CreateButton href="/app/agents/create?admin=true">
Create Your First Agent
</CreateButton>
</div>
)}
</div>
</div>
);
}
export default function Page() {
const [currentPage, setCurrentPage] = useState(1);
const { personas, totalItems, isLoading, error, refresh } = useAdminPersonas({
pageNum: currentPage - 1, // Backend uses 0-indexed pages
pageSize: PAGE_SIZE,
});
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
{isLoading && <ThreeDotsLoader />}
{error && (
<ErrorCallout
errorTitle="Failed to load agents"
errorMsg={
error?.info?.message ||
error?.info?.detail ||
"An unknown error occurred"
}
/>
)}
{!isLoading && !error && (
<MainContent
personas={personas}
totalItems={totalItems}
currentPage={currentPage}
onPageChange={setCurrentPage}
refreshPersonas={refresh}
/>
)}
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}
export { default } from "@/refresh-pages/admin/AgentsPage";

View File

@@ -395,7 +395,7 @@ function SeatsCard({
<InputLayouts.Vertical title="Seats">
<InputNumber
value={newSeatCount}
onChange={setNewSeatCount}
onChange={(v) => setNewSeatCount(v ?? 1)}
min={1}
defaultValue={totalSeats}
showReset

View File

@@ -230,7 +230,7 @@ export default function CheckoutView({ onAdjustPlan }: CheckoutViewProps) {
>
<InputNumber
value={seats}
onChange={setSeats}
onChange={(v) => setSeats(v ?? minRequiredSeats)}
min={minRequiredSeats}
defaultValue={minRequiredSeats}
showReset

View File

@@ -260,7 +260,7 @@ export default function VoiceProviderSetupModal({
<SvgArrowExchange className="size-3 text-text-04" />
</div>
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
<SvgOnyxLogo size={24} className="shrink-0" />
</div>
</div>
);

View File

@@ -69,7 +69,7 @@ export const WebProviderSetupModal = memo(
<SvgArrowExchange className="size-3 text-text-04" />
</div>
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
<SvgOnyxLogo size={24} className="shrink-0" />
</div>
</div>
);

View File

@@ -1372,7 +1372,7 @@ export default function Page() {
} logo`,
fallback:
selectedContentProviderType === "onyx_web_crawler" ? (
<SvgOnyxLogo size={24} className="text-text-05" />
<SvgOnyxLogo size={24} />
) : undefined,
size: 24,
containerSize: 28,

View File

@@ -0,0 +1 @@
export { default } from "@/refresh-pages/admin/GroupsPage/CreateGroupPage";

View File

@@ -11,22 +11,15 @@ import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
import { useState, useEffect } from "react";
import { useSettingsContext } from "@/providers/SettingsProvider";
import FrostedDiv from "@/refresh-components/FrostedDiv";
import { cn } from "@/lib/utils";
export interface WelcomeMessageProps {
agent?: MinimalPersonaSnapshot;
isDefaultAgent: boolean;
/** Optional right-aligned element rendered on the same row as the greeting (e.g. model selector). */
rightChildren?: React.ReactNode;
/** When true, the greeting/logo content is hidden (but space is preserved). Used at max models. */
hideTitle?: boolean;
}
export default function WelcomeMessage({
agent,
isDefaultAgent,
rightChildren,
hideTitle,
}: WelcomeMessageProps) {
const settings = useSettingsContext();
const enterpriseSettings = settings?.enterpriseSettings;
@@ -46,10 +39,8 @@ export default function WelcomeMessage({
if (isDefaultAgent) {
content = (
<div data-testid="onyx-logo" className="flex flex-col items-start gap-2">
<div className="flex items-center justify-center size-9 p-0.5">
<Logo folded size={32} />
</div>
<div data-testid="onyx-logo" className="flex flex-row items-center gap-4">
<Logo folded size={32} />
<Text as="p" headingH2>
{greeting}
</Text>
@@ -57,15 +48,17 @@ export default function WelcomeMessage({
);
} else if (agent) {
content = (
<div
data-testid="agent-name-display"
className="flex flex-col items-start gap-2"
>
<AgentAvatar agent={agent} size={36} />
<Text as="p" headingH2>
{agent.name}
</Text>
</div>
<>
<div
data-testid="agent-name-display"
className="flex flex-row items-center gap-3"
>
<AgentAvatar agent={agent} size={36} />
<Text as="p" headingH2>
{agent.name}
</Text>
</div>
</>
);
}
@@ -76,24 +69,9 @@ export default function WelcomeMessage({
return (
<FrostedDiv
data-testid="chat-intro"
wrapperClassName="w-full"
className="flex flex-col items-center justify-center gap-3 w-full max-w-[var(--app-page-main-content-width)] mx-auto"
className="flex flex-col items-center justify-center gap-3 w-full max-w-[var(--app-page-main-content-width)]"
>
{rightChildren ? (
<div className="flex items-end gap-2 w-full">
<div
className={cn(
"flex-1 min-w-0 min-h-[80px] px-2 py-1",
hideTitle && "invisible"
)}
>
{content}
</div>
<div className="shrink-0">{rightChildren}</div>
</div>
) : (
content
)}
{content}
</FrostedDiv>
);
}

View File

@@ -159,10 +159,6 @@ export interface Message {
overridden_model?: string;
stopReason?: StreamStopReason | null;
// Multi-model answer generation
preferredResponseId?: number | null;
modelDisplayName?: string | null;
// new gen
packets: Packet[];
packetCount?: number; // Tracks packet count for React memo comparison (avoids reading from mutated array)
@@ -235,9 +231,6 @@ export interface BackendMessage {
parentMessageId: number | null;
refined_answer_improvement: boolean | null;
is_agentic: boolean | null;
// Multi-model answer generation
preferred_response_id: number | null;
model_display_name: string | null;
}
export interface MessageResponseIDInfo {
@@ -245,12 +238,6 @@ export interface MessageResponseIDInfo {
reserved_assistant_message_id: number; // TODO: rename to agent — https://linear.app/onyx-app/issue/ENG-3766
}
export interface MultiModelMessageResponseIDInfo {
user_message_id: number | null;
reserved_assistant_message_ids: number[];
model_names: string[];
}
export interface UserKnowledgeFilePacket {
user_files: FileDescriptor[];
}

View File

@@ -1,149 +0,0 @@
"use client";
import { useCallback } from "react";
import { Button } from "@opal/components";
import { Hoverable } from "@opal/core";
import { SvgEyeClosed, SvgMoreHorizontal, SvgX } from "@opal/icons";
import Text from "@/refresh-components/texts/Text";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import AgentMessage, {
AgentMessageProps,
} from "@/app/app/message/messageComponents/AgentMessage";
import { cn } from "@/lib/utils";
export interface MultiModelPanelProps {
/** Index of this model in the selectedModels array (used for Hoverable group key) */
modelIndex: number;
/** Provider name for icon lookup */
provider: string;
/** Model name for icon lookup and display */
modelName: string;
/** Display-friendly model name */
displayName: string;
/** Whether this panel is the preferred/selected response */
isPreferred: boolean;
/** Whether this panel is currently hidden */
isHidden: boolean;
/** Whether this is a non-preferred panel in selection mode (pushed off-screen) */
isNonPreferredInSelection: boolean;
/** Callback when user clicks this panel to select as preferred */
onSelect: () => void;
/** Callback to hide/show this panel */
onToggleVisibility: () => void;
/** Props to pass through to AgentMessage */
agentMessageProps: AgentMessageProps;
}
export default function MultiModelPanel({
modelIndex,
provider,
modelName,
displayName,
isPreferred,
isHidden,
isNonPreferredInSelection,
onSelect,
onToggleVisibility,
agentMessageProps,
}: MultiModelPanelProps) {
const ProviderIcon = getProviderIcon(provider, modelName);
const handlePanelClick = useCallback(() => {
if (!isHidden) {
onSelect();
}
}, [isHidden, onSelect]);
// Hidden/collapsed panel — compact strip: icon + strikethrough name + eye icon
if (isHidden) {
return (
<div className="flex items-center gap-1.5 rounded-08 bg-background-tint-00 px-2 py-1 opacity-50">
<div className="flex items-center justify-center size-5 shrink-0">
<ProviderIcon size={16} />
</div>
<Text secondaryBody text02 nowrap className="line-through">
{displayName}
</Text>
<Button
prominence="tertiary"
icon={SvgEyeClosed}
size="2xs"
onClick={onToggleVisibility}
tooltip="Show response"
/>
</div>
);
}
const hoverGroup = `panel-${modelIndex}`;
return (
<Hoverable.Root group={hoverGroup}>
<div
className="flex flex-col min-w-0 gap-3 cursor-pointer"
onClick={handlePanelClick}
>
{/* Panel header */}
<div
className={cn(
"flex items-center gap-1.5 rounded-12 px-2 py-1",
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
)}
>
<div className="flex items-center justify-center size-5 shrink-0">
<ProviderIcon size={16} />
</div>
<Text mainUiAction text04 nowrap className="flex-1 min-w-0 truncate">
{displayName}
</Text>
{isPreferred && (
<Text secondaryBody nowrap className="text-action-link-05 shrink-0">
Preferred Response
</Text>
)}
<Button
prominence="tertiary"
icon={SvgMoreHorizontal}
size="2xs"
tooltip="More"
onClick={(e) => e.stopPropagation()}
/>
<Button
prominence="tertiary"
icon={SvgX}
size="2xs"
onClick={(e) => {
e.stopPropagation();
onToggleVisibility();
}}
tooltip="Hide response"
/>
</div>
{/* "Select This Response" hover affordance */}
{!isPreferred && !isNonPreferredInSelection && (
<Hoverable.Item group={hoverGroup} variant="opacity-on-hover">
<div className="flex justify-center pointer-events-none">
<div className="flex items-center h-6 bg-background-tint-00 rounded-08 px-1 shadow-sm">
<Text
secondaryBody
className="font-semibold text-text-03 px-1 whitespace-nowrap"
>
Select This Response
</Text>
</div>
</div>
</Hoverable.Item>
)}
{/* Response body */}
<div className={cn(isNonPreferredInSelection && "pointer-events-none")}>
<AgentMessage
{...agentMessageProps}
hideFooter={isNonPreferredInSelection}
/>
</div>
</div>
</Hoverable.Root>
);
}

View File

@@ -1,229 +0,0 @@
"use client";
import { useState, useCallback, useMemo } from "react";
import { Packet } from "@/app/app/services/streamingModels";
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
import { FeedbackType, Message } from "@/app/app/interfaces";
import { LlmManager } from "@/lib/hooks";
import { RegenerationFactory } from "@/app/app/message/messageComponents/AgentMessage";
import MultiModelPanel from "@/app/app/message/MultiModelPanel";
import { cn } from "@/lib/utils";
export interface MultiModelResponse {
modelIndex: number;
provider: string;
modelName: string;
displayName: string;
packets: Packet[];
packetCount: number;
nodeId: number;
messageId?: number;
isHighlighted?: boolean;
currentFeedback?: FeedbackType | null;
isGenerating?: boolean;
}
export interface MultiModelResponseViewProps {
responses: MultiModelResponse[];
chatState: FullChatState;
llmManager: LlmManager | null;
onRegenerate?: RegenerationFactory;
parentMessage?: Message | null;
otherMessagesCanSwitchTo?: number[];
onMessageSelection?: (nodeId: number) => void;
}
export default function MultiModelResponseView({
responses,
chatState,
llmManager,
onRegenerate,
parentMessage,
otherMessagesCanSwitchTo,
onMessageSelection,
}: MultiModelResponseViewProps) {
const [preferredIndex, setPreferredIndex] = useState<number | null>(null);
const [hiddenPanels, setHiddenPanels] = useState<Set<number>>(new Set());
const isGenerating = useMemo(
() => responses.some((r) => r.isGenerating),
[responses]
);
const visibleResponses = useMemo(
() => responses.filter((r) => !hiddenPanels.has(r.modelIndex)),
[responses, hiddenPanels]
);
const hiddenResponses = useMemo(
() => responses.filter((r) => hiddenPanels.has(r.modelIndex)),
[responses, hiddenPanels]
);
const toggleVisibility = useCallback(
(modelIndex: number) => {
setHiddenPanels((prev) => {
const next = new Set(prev);
if (next.has(modelIndex)) {
next.delete(modelIndex);
} else {
// Don't hide the last visible panel
const visibleCount = responses.length - next.size;
if (visibleCount <= 1) return prev;
next.add(modelIndex);
}
return next;
});
},
[responses.length]
);
const handleSelectPreferred = useCallback(
(modelIndex: number) => {
setPreferredIndex(modelIndex);
const response = responses[modelIndex];
if (!response) return;
// Sync with message tree — mark this response as the latest child
// so the next message chains from it.
if (onMessageSelection) {
onMessageSelection(response.nodeId);
}
},
[responses, onMessageSelection]
);
// Selection mode when preferred is set and not generating
const showSelectionMode =
preferredIndex !== null && !isGenerating && visibleResponses.length > 1;
// Build common panel props
const buildPanelProps = useCallback(
(response: MultiModelResponse, isNonPreferred: boolean) => ({
modelIndex: response.modelIndex,
provider: response.provider,
modelName: response.modelName,
displayName: response.displayName,
isPreferred: preferredIndex === response.modelIndex,
isHidden: false as const,
isNonPreferredInSelection: isNonPreferred,
onSelect: () => handleSelectPreferred(response.modelIndex),
onToggleVisibility: () => toggleVisibility(response.modelIndex),
agentMessageProps: {
rawPackets: response.packets,
packetCount: response.packetCount,
chatState,
nodeId: response.nodeId,
messageId: response.messageId,
currentFeedback: response.currentFeedback,
llmManager,
otherMessagesCanSwitchTo,
onMessageSelection,
onRegenerate,
parentMessage,
},
}),
[
preferredIndex,
handleSelectPreferred,
toggleVisibility,
chatState,
llmManager,
otherMessagesCanSwitchTo,
onMessageSelection,
onRegenerate,
parentMessage,
]
);
// Shared renderer for hidden panels (inline in the flex row)
const renderHiddenPanels = () =>
hiddenResponses.map((r) => (
<div key={r.modelIndex} className="w-[240px] shrink-0">
<MultiModelPanel
modelIndex={r.modelIndex}
provider={r.provider}
modelName={r.modelName}
displayName={r.displayName}
isPreferred={false}
isHidden
isNonPreferredInSelection={false}
onSelect={() => handleSelectPreferred(r.modelIndex)}
onToggleVisibility={() => toggleVisibility(r.modelIndex)}
agentMessageProps={buildPanelProps(r, false).agentMessageProps}
/>
</div>
));
if (showSelectionMode) {
// ── Selection Layout ──
// Preferred stays at normal chat width, centered.
// Non-preferred panels are pushed to the viewport edges and clip off-screen.
const preferredIdx = visibleResponses.findIndex(
(r) => r.modelIndex === preferredIndex
);
const preferred = visibleResponses[preferredIdx];
const leftPanels = visibleResponses.slice(0, preferredIdx);
const rightPanels = visibleResponses.slice(preferredIdx + 1);
// Non-preferred panel width and gap between panels
const PANEL_W = 400;
const GAP = 16;
return (
<div className="w-full relative overflow-hidden">
{/* Preferred — centered at normal chat width, in flow to set container height */}
{preferred && (
<div className="w-full max-w-[720px] min-w-[400px] mx-auto">
<MultiModelPanel {...buildPanelProps(preferred, false)} />
</div>
)}
{/* Non-preferred on the left — anchored to the left of the preferred panel */}
{leftPanels.map((r, i) => (
<div
key={r.modelIndex}
className="absolute top-0"
style={{
width: `${PANEL_W}px`,
// Right edge of this panel sits just left of the preferred panel
right: `calc(50% + var(--app-page-main-content-width) / 2 + ${
GAP + i * (PANEL_W + GAP)
}px)`,
}}
>
<MultiModelPanel {...buildPanelProps(r, true)} />
</div>
))}
{/* Non-preferred on the right — anchored to the right of the preferred panel */}
{rightPanels.map((r, i) => (
<div
key={r.modelIndex}
className="absolute top-0"
style={{
width: `${PANEL_W}px`,
// Left edge of this panel sits just right of the preferred panel
left: `calc(50% + var(--app-page-main-content-width) / 2 + ${
GAP + i * (PANEL_W + GAP)
}px)`,
}}
>
<MultiModelPanel {...buildPanelProps(r, true)} />
</div>
))}
</div>
);
}
// ── Generation Layout (equal panels) ──
return (
<div className="flex gap-6 items-start justify-center">
{visibleResponses.map((r) => (
<div key={r.modelIndex} className="flex-1 min-w-[400px] max-w-[720px]">
<MultiModelPanel {...buildPanelProps(r, false)} />
</div>
))}
{renderHiddenPanels()}
</div>
);
}

View File

@@ -49,8 +49,6 @@ export interface AgentMessageProps {
parentMessage?: Message | null;
// Duration in seconds for processing this message (agent messages only)
processingDurationSeconds?: number;
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
hideFooter?: boolean;
}
// TODO: Consider more robust comparisons:
@@ -78,8 +76,7 @@ function arePropsEqual(
prev.parentMessage?.messageId === next.parentMessage?.messageId &&
prev.llmManager?.isLoadingProviders ===
next.llmManager?.isLoadingProviders &&
prev.processingDurationSeconds === next.processingDurationSeconds &&
prev.hideFooter === next.hideFooter
prev.processingDurationSeconds === next.processingDurationSeconds
// Skip: chatState.regenerate, chatState.setPresentingDocument,
// most of llmManager, onMessageSelection (function/object props)
);
@@ -98,7 +95,6 @@ const AgentMessage = React.memo(function AgentMessage({
onRegenerate,
parentMessage,
processingDurationSeconds,
hideFooter,
}: AgentMessageProps) {
const markdownRef = useRef<HTMLDivElement>(null);
const finalAnswerRef = useRef<HTMLDivElement>(null);
@@ -330,7 +326,7 @@ const AgentMessage = React.memo(function AgentMessage({
</div>
{/* Feedback buttons - only show when streaming and rendering complete */}
{isComplete && !hideFooter && (
{isComplete && (
<MessageToolbar
nodeId={nodeId}
messageId={messageId}

View File

@@ -12,7 +12,6 @@ import {
FileChatDisplay,
Message,
MessageResponseIDInfo,
MultiModelMessageResponseIDInfo,
ResearchType,
RetrievalType,
StreamingError,
@@ -97,7 +96,6 @@ export type PacketType =
| FileChatDisplay
| StreamingError
| MessageResponseIDInfo
| MultiModelMessageResponseIDInfo
| StreamStopInfo
| UserKnowledgeFilePacket
| Packet;
@@ -111,13 +109,6 @@ export type MessageOrigin =
| "slackbot"
| "unknown";
export interface LLMOverride {
model_provider: string;
model_version: string;
temperature?: number;
display_name?: string;
}
export interface SendMessageParams {
message: string;
fileDescriptors?: FileDescriptor[];
@@ -133,8 +124,6 @@ export interface SendMessageParams {
modelProvider?: string;
modelVersion?: string;
temperature?: number;
// Multi-model: send multiple LLM overrides for parallel generation
llmOverrides?: LLMOverride[];
// Origin of the message for telemetry tracking
origin?: MessageOrigin;
// Additional context injected into the LLM call but not stored/shown in chat.
@@ -155,7 +144,6 @@ export async function* sendMessage({
modelProvider,
modelVersion,
temperature,
llmOverrides,
origin,
additionalContext,
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
@@ -177,8 +165,6 @@ export async function* sendMessage({
model_version: modelVersion,
}
: null,
// Multi-model: list of LLM overrides for parallel generation
llm_overrides: llmOverrides ?? null,
// Default to "unknown" for consistency with backend; callers should set explicitly
origin: origin ?? "unknown",
additional_context: additionalContext ?? null,
@@ -202,20 +188,6 @@ export async function* sendMessage({
yield* handleSSEStream<PacketType>(response, signal);
}
export async function setPreferredResponse(
userMessageId: number,
preferredResponseId: number
): Promise<Response> {
return fetch("/api/chat/set-preferred-response", {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
user_message_id: userMessageId,
preferred_response_id: preferredResponseId,
}),
});
}
export async function nameChatSession(chatSessionId: string) {
const response = await fetch("/api/chat/rename-chat-session", {
method: "PUT",
@@ -385,9 +357,6 @@ export function processRawChatHistory(
overridden_model: messageInfo.overridden_model,
packets: packetsForMessage || [],
currentFeedback: messageInfo.current_feedback as FeedbackType | null,
// Multi-model answer generation
preferredResponseId: messageInfo.preferred_response_id ?? null,
modelDisplayName: messageInfo.model_display_name ?? null,
};
messages.set(messageInfo.message_id, message);

View File

@@ -403,7 +403,6 @@ export interface Placement {
turn_index: number;
tab_index?: number; // For parallel tool calls - tools with same turn_index but different tab_index run in parallel
sub_turn_index?: number | null;
model_index?: number | null; // For multi-model answer generation - identifies which model produced this packet
}
// Packet wrapper for streaming objects

View File

@@ -28,7 +28,12 @@ export default function Layout({ children }: LayoutProps) {
<SettingsLayouts.Header icon={SvgSliders} title="Settings" separator />
<SettingsLayouts.Body>
<Section flexDirection="row" alignItems="start" gap={1.5}>
<Section
flexDirection="row"
justifyContent="start"
alignItems="start"
gap={1.5}
>
{/* Left: Tab Navigation */}
<div
data-testid="settings-left-tab-navigation"

View File

@@ -7,8 +7,11 @@ import { processRawChatHistory } from "@/app/app/services/lib";
import { getLatestMessageChain } from "@/app/app/services/messageTree";
import HumanMessage from "@/app/app/message/HumanMessage";
import AgentMessage from "@/app/app/message/messageComponents/AgentMessage";
import { Callout } from "@/components/ui/callout";
import OnyxInitializingLoader from "@/components/OnyxInitializingLoader";
import { Section } from "@/layouts/general-layouts";
import { IllustrationContent } from "@opal/layouts";
import SvgNotFound from "@opal/illustrations/not-found";
import { Button } from "@opal/components";
import { Persona } from "@/app/admin/agents/interfaces";
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
import PreviewModal from "@/sections/modals/PreviewModal";
@@ -33,12 +36,17 @@ export default function SharedChatDisplay({
if (!chatSession) {
return (
<div className="min-h-full w-full">
<div className="mx-auto w-fit pt-8">
<Callout type="danger" title="Shared Chat Not Found">
Did not find a shared chat with the specified ID.
</Callout>
</div>
<div className="h-full w-full flex flex-col items-center justify-center">
<Section flexDirection="column" alignItems="center" gap={1}>
<IllustrationContent
illustration={SvgNotFound}
title="Shared chat not found"
description="Did not find a shared chat with the specified ID."
/>
<Button href="/app" prominence="secondary">
Start a new chat
</Button>
</Section>
</div>
);
}
@@ -51,12 +59,17 @@ export default function SharedChatDisplay({
if (firstMessage === undefined) {
return (
<div className="min-h-full w-full">
<div className="mx-auto w-fit pt-8">
<Callout type="danger" title="Shared Chat Not Found">
No messages found in shared chat.
</Callout>
</div>
<div className="h-full w-full flex flex-col items-center justify-center">
<Section flexDirection="column" alignItems="center" gap={1}>
<IllustrationContent
illustration={SvgNotFound}
title="Shared chat not found"
description="No messages found in shared chat."
/>
<Button href="/app" prominence="secondary">
Start a new chat
</Button>
</Section>
</div>
);
}

View File

@@ -13,6 +13,7 @@ import {
type KeyboardEvent,
} from "react";
import { useRouter } from "next/navigation";
import { getPastedFilesIfNoText } from "@/lib/clipboard";
import { cn, isImageFile } from "@/lib/utils";
import { Disabled } from "@opal/core";
import {
@@ -230,21 +231,11 @@ const InputBar = memo(
const handlePaste = useCallback(
(event: ClipboardEvent) => {
const items = event.clipboardData?.items;
if (items) {
const pastedFiles: File[] = [];
for (let i = 0; i < items.length; i++) {
const item = items[i];
if (item && item.kind === "file") {
const file = item.getAsFile();
if (file) pastedFiles.push(file);
}
}
if (pastedFiles.length > 0) {
event.preventDefault();
// Context handles session binding internally
uploadFiles(pastedFiles);
}
const pastedFiles = getPastedFilesIfNoText(event.clipboardData);
if (pastedFiles.length > 0) {
event.preventDefault();
// Context handles session binding internally
uploadFiles(pastedFiles);
}
},
[uploadFiles]

View File

@@ -413,7 +413,7 @@ const MemoizedBuildSidebarInner = memo(
return (
<SidebarWrapper folded={folded} onFoldClick={onFoldClick}>
<SidebarBody
actionButtons={
pinnedContent={
<div className="flex flex-col gap-0.5">
{newBuildButton}
{buildConfigurePanel}

View File

@@ -2,7 +2,9 @@
/* Base layers */
--z-base: 0;
--z-content: 1;
--z-settings-header: 8;
/* Settings header must sit above sticky table headers (--z-sticky: 10) so
the page header scrolls over pinned columns without being obscured. */
--z-settings-header: 11;
--z-app-layout: 9;
--z-sticky: 10;

View File

@@ -0,0 +1 @@
export { default } from "@/refresh-pages/admin/GroupsPage/CreateGroupPage";

View File

@@ -17,6 +17,7 @@ import StatsOverlayLoader from "@/components/dev/StatsOverlayLoader";
import AppHealthBanner from "@/sections/AppHealthBanner";
import CustomAnalyticsScript from "@/providers/CustomAnalyticsScript";
import ProductGatingWrapper from "@/providers/ProductGatingWrapper";
import SWRConfigProvider from "@/providers/SWRConfigProvider";
const hankenGrotesk = Hanken_Grotesk({
subsets: ["latin"],
@@ -79,21 +80,23 @@ export default function RootLayout({
<div className="text-text min-h-screen bg-background">
<TooltipProvider>
<PHProvider>
<AppHealthBanner />
<AppProvider>
<DynamicMetadata />
<CustomAnalyticsScript />
<Suspense fallback={null}>
<PostHogPageView />
</Suspense>
<div id={MODAL_ROOT_ID} className="h-screen w-screen">
<ProductGatingWrapper>{children}</ProductGatingWrapper>
</div>
{process.env.NEXT_PUBLIC_POSTHOG_KEY && <WebVitals />}
{process.env.NEXT_PUBLIC_ENABLE_STATS === "true" && (
<StatsOverlayLoader />
)}
</AppProvider>
<SWRConfigProvider>
<AppHealthBanner />
<AppProvider>
<DynamicMetadata />
<CustomAnalyticsScript />
<Suspense fallback={null}>
<PostHogPageView />
</Suspense>
<div id={MODAL_ROOT_ID} className="h-screen w-screen">
<ProductGatingWrapper>{children}</ProductGatingWrapper>
</div>
{process.env.NEXT_PUBLIC_POSTHOG_KEY && <WebVitals />}
{process.env.NEXT_PUBLIC_ENABLE_STATS === "true" && (
<StatsOverlayLoader />
)}
</AppProvider>
</SWRConfigProvider>
</PHProvider>
</TooltipProvider>
</div>

View File

@@ -459,7 +459,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
onResubmit={handleResubmitLastMessage}
deepResearchEnabled={deepResearchEnabled}
anchorNodeId={anchorNodeId}
selectedModels={[]}
/>
</ChatScrollContainer>
</>

View File

@@ -21,7 +21,7 @@ import Text from "@/refresh-components/texts/Text";
import { Section } from "@/layouts/general-layouts";
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
import { SvgCheck, SvgClock, SvgTag } from "@opal/icons";
import FilterButton from "@/refresh-components/buttons/FilterButton";
import { FilterButton } from "@opal/components";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import useFilter from "@/hooks/useFilter";
import { LineItemButton } from "@opal/components";
@@ -217,7 +217,7 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
<Popover open={timeFilterOpen} onOpenChange={setTimeFilterOpen}>
<Popover.Trigger asChild>
<FilterButton
leftIcon={SvgClock}
icon={SvgClock}
active={!!timeFilter}
onClear={() => {
setTimeFilter(null);
@@ -253,7 +253,7 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
<Popover open={tagFilterOpen} onOpenChange={setTagFilterOpen}>
<Popover.Trigger asChild>
<FilterButton
leftIcon={SvgTag}
icon={SvgTag}
active={selectedTags.length > 0}
onClear={() => {
setSelectedTags([]);

View File

@@ -1,7 +1,15 @@
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
interface AuthTypeAPIResponse {
auth_type: string;
requires_verification: boolean;
anonymous_user_enabled: boolean | null;
password_min_length: number;
has_users: boolean;
oauth_enabled: boolean;
}
export interface AuthTypeMetadata {
authType: AuthType;
autoRedirect: boolean;
@@ -22,6 +30,24 @@ const DEFAULT_AUTH_TYPE_METADATA: AuthTypeMetadata = {
oauthEnabled: false,
};
async function fetchAuthTypeMetadata(url: string): Promise<AuthTypeMetadata> {
const res = await fetch(url);
if (!res.ok) throw new Error("Failed to fetch auth type metadata");
const data: AuthTypeAPIResponse = await res.json();
const authType = NEXT_PUBLIC_CLOUD_ENABLED
? AuthType.CLOUD
: (data.auth_type as AuthType);
return {
authType,
autoRedirect: authType === AuthType.OIDC || authType === AuthType.SAML,
requiresVerification: data.requires_verification,
anonymousUserEnabled: data.anonymous_user_enabled,
passwordMinLength: data.password_min_length,
hasUsers: data.has_users,
oauthEnabled: data.oauth_enabled,
};
}
export function useAuthTypeMetadata(): {
authTypeMetadata: AuthTypeMetadata;
isLoading: boolean;
@@ -29,7 +55,7 @@ export function useAuthTypeMetadata(): {
} {
const { data, error, isLoading } = useSWR<AuthTypeMetadata>(
"/api/auth/type",
errorHandlingFetcher,
fetchAuthTypeMetadata,
{
revalidateOnFocus: false,
revalidateOnReconnect: false,
@@ -37,14 +63,6 @@ export function useAuthTypeMetadata(): {
}
);
if (NEXT_PUBLIC_CLOUD_ENABLED && data) {
return {
authTypeMetadata: { ...data, authType: AuthType.CLOUD },
isLoading,
error,
};
}
return {
authTypeMetadata: data ?? DEFAULT_AUTH_TYPE_METADATA,
isLoading,

View File

@@ -3,7 +3,6 @@
import {
buildChatUrl,
getAvailableContextTokens,
LLMOverride,
nameChatSession,
updateLlmOverrideForChatSession,
} from "@/app/app/services/lib";
@@ -34,7 +33,6 @@ import {
FileDescriptor,
Message,
MessageResponseIDInfo,
MultiModelMessageResponseIDInfo,
RegenerationState,
RetrievalType,
StreamingError,
@@ -72,7 +70,6 @@ import {
} from "@/app/app/stores/useChatSessionStore";
import { Packet, MessageStart } from "@/app/app/services/streamingModels";
import useAgentPreferences from "@/hooks/useAgentPreferences";
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
import { useForcedTools } from "@/lib/hooks/useForcedTools";
import { ProjectFile, useProjectsContext } from "@/providers/ProjectsContext";
import { useAppParams } from "@/hooks/appNavigation";
@@ -97,8 +94,6 @@ export interface OnSubmitProps {
regenerationRequest?: RegenerationRequest | null;
// Additional context injected into the LLM call but not stored/shown in chat.
additionalContext?: string;
// Multi-model chat: up to 3 models selected for parallel comparison.
selectedModels?: SelectedModel[];
}
interface RegenerationRequest {
@@ -375,10 +370,7 @@ export default function useChatController({
modelOverride,
regenerationRequest,
additionalContext,
selectedModels,
}: OnSubmitProps) => {
// Check if this is multi-model mode (2 or 3 models selected)
const isMultiModelMode = selectedModels && selectedModels.length >= 2;
const projectId = params(SEARCH_PARAM_NAMES.PROJECT_ID);
{
const params = new URLSearchParams(searchParams?.toString() || "");
@@ -609,7 +601,6 @@ export default function useChatController({
// immediately reflects the user message
let initialUserNode: Message;
let initialAgentNode: Message;
let initialAssistantNodes: Message[] = [];
if (regenerationRequest) {
// For regeneration: keep the existing user message, only create new agent
@@ -632,30 +623,12 @@ export default function useChatController({
);
initialUserNode = result.initialUserNode;
initialAgentNode = result.initialAgentNode;
// In multi-model mode, create N assistant nodes (one per selected model)
if (isMultiModelMode && selectedModels) {
for (let i = 0; i < selectedModels.length; i++) {
initialAssistantNodes.push(
buildEmptyMessage({
messageType: "assistant",
parentNodeId: initialUserNode.nodeId,
nodeIdOffset: i + 1,
})
);
}
}
}
// make messages appear + clear input bar
let messagesToUpsert: Message[];
if (regenerationRequest) {
messagesToUpsert = [initialAgentNode];
} else if (isMultiModelMode) {
messagesToUpsert = [initialUserNode, ...initialAssistantNodes];
} else {
messagesToUpsert = [initialUserNode, initialAgentNode];
}
const messagesToUpsert = regenerationRequest
? [initialAgentNode] // Only upsert the new agent for regeneration
: [initialUserNode, initialAgentNode]; // Upsert both for normal/edit flow
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: messagesToUpsert,
completeMessageTreeOverride: currentMessageTreeLocal,
@@ -689,24 +662,6 @@ export default function useChatController({
let newUserMessageId: number | null = null;
let newAgentMessageId: number | null = null;
// Multi-model mode state tracking (dynamically sized based on selected models)
const numModels = selectedModels?.length ?? 0;
let newAssistantMessageIds: (number | null)[] = isMultiModelMode
? Array(numModels).fill(null)
: [];
let packetsPerModel: Packet[][] = isMultiModelMode
? Array.from({ length: numModels }, () => [])
: [];
let modelDisplayNames: string[] = isMultiModelMode
? selectedModels?.map((m) => m.displayName) ?? []
: [];
let documentsPerModel: OnyxDocument[][] = isMultiModelMode
? Array.from({ length: numModels }, () => [])
: [];
let citationsPerModel: (CitationMap | null)[] = isMultiModelMode
? Array(numModels).fill(null)
: [];
try {
const lastSuccessfulMessageId = getLastSuccessfulMessageId(
currentMessageTreeLocal
@@ -755,15 +710,13 @@ export default function useChatController({
filterManager.timeRange,
filterManager.selectedTags
),
modelProvider: isMultiModelMode
? undefined
: modelOverride?.name || llmManager.currentLlm.name || undefined,
modelVersion: isMultiModelMode
? undefined
: modelOverride?.modelName ||
llmManager.currentLlm.modelName ||
searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
undefined,
modelProvider:
modelOverride?.name || llmManager.currentLlm.name || undefined,
modelVersion:
modelOverride?.modelName ||
llmManager.currentLlm.modelName ||
searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
undefined,
temperature: llmManager.temperature || undefined,
deepResearch,
enabledToolIds:
@@ -775,12 +728,6 @@ export default function useChatController({
forcedToolId: effectiveForcedToolId,
origin: messageOrigin,
additionalContext,
llmOverrides: isMultiModelMode
? selectedModels!.map((model) => ({
model_provider: model.name,
model_version: model.modelName,
}))
: undefined,
});
const delay = (ms: number) => {
@@ -833,26 +780,6 @@ export default function useChatController({
.reserved_assistant_message_id;
}
// Multi-model: handle reserved IDs for N parallel model responses
if (
isMultiModelMode &&
Object.hasOwn(packet, "reserved_assistant_message_ids") &&
Array.isArray(
(packet as MultiModelMessageResponseIDInfo)
.reserved_assistant_message_ids
)
) {
const multiPacket = packet as MultiModelMessageResponseIDInfo;
newAssistantMessageIds =
multiPacket.reserved_assistant_message_ids;
newUserMessageId =
multiPacket.user_message_id ?? newUserMessageId;
// Capture backend model names for display on reload
if (multiPacket.model_names?.length) {
modelDisplayNames = multiPacket.model_names;
}
}
if (Object.hasOwn(packet, "user_files")) {
const userFiles = (packet as UserKnowledgeFilePacket).user_files;
// Ensure files are unique by id
@@ -896,73 +823,32 @@ export default function useChatController({
updateCanContinue(true, frozenSessionId);
}
} else if (Object.hasOwn(packet, "obj")) {
const typedPacket = packet as Packet;
packets.push(packet as Packet);
packetsVersion++;
// In multi-model mode, route packets by model_index
if (isMultiModelMode) {
const modelIndex = typedPacket.placement?.model_index ?? 0;
if (
modelIndex >= 0 &&
modelIndex < packetsPerModel.length &&
packetsPerModel[modelIndex]
) {
packetsPerModel[modelIndex] = [
...packetsPerModel[modelIndex]!,
typedPacket,
];
// Check if the packet contains document information
const packetObj = (packet as Packet).obj;
const packetObj = typedPacket.obj;
if (packetObj.type === "citation_info") {
const citationInfo = packetObj as {
type: "citation_info";
citation_number: number;
document_id: string;
};
citationsPerModel[modelIndex] = {
...(citationsPerModel[modelIndex] || {}),
[citationInfo.citation_number]: citationInfo.document_id,
};
} else if (packetObj.type === "message_start") {
const messageStart = packetObj as MessageStart;
if (messageStart.final_documents) {
documentsPerModel[modelIndex] =
messageStart.final_documents;
if (modelIndex === 0 && initialAssistantNodes[0]) {
updateSelectedNodeForDocDisplay(
frozenSessionId,
initialAssistantNodes[0].nodeId
);
}
}
}
}
} else {
// Single model mode
packets.push(typedPacket);
packetsVersion++;
const packetObj = typedPacket.obj;
if (packetObj.type === "citation_info") {
const citationInfo = packetObj as {
type: "citation_info";
citation_number: number;
document_id: string;
};
citations = {
...(citations || {}),
[citationInfo.citation_number]: citationInfo.document_id,
};
} else if (packetObj.type === "message_start") {
const messageStart = packetObj as MessageStart;
if (messageStart.final_documents) {
documents = messageStart.final_documents;
updateSelectedNodeForDocDisplay(
frozenSessionId,
initialAgentNode.nodeId
);
}
if (packetObj.type === "citation_info") {
// Individual citation packet from backend streaming
const citationInfo = packetObj as {
type: "citation_info";
citation_number: number;
document_id: string;
};
// Incrementally build citations map
citations = {
...(citations || {}),
[citationInfo.citation_number]: citationInfo.document_id,
};
} else if (packetObj.type === "message_start") {
const messageStart = packetObj as MessageStart;
if (messageStart.final_documents) {
documents = messageStart.final_documents;
updateSelectedNodeForDocDisplay(
frozenSessionId,
initialAgentNode.nodeId
);
}
}
} else {
@@ -974,48 +860,8 @@ export default function useChatController({
parentMessage =
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
// Build messages to upsert based on mode
let messagesToUpsertInLoop: Message[];
if (isMultiModelMode) {
// Multi-model mode: update user node + all N assistant nodes
const updatedUserNode = {
...initialUserNode,
messageId: newUserMessageId ?? undefined,
files: files,
};
const updatedAssistantNodes = initialAssistantNodes.map(
(node, idx) => ({
...node,
messageId: newAssistantMessageIds[idx] ?? undefined,
message: "",
type: "assistant" as const,
retrievalType,
query: query,
documents: documentsPerModel[idx] || [],
citations: citationsPerModel[idx] || {},
files: [] as FileDescriptor[],
toolCall: null,
stackTrace: null,
overridden_model: selectedModels?.[idx]?.displayName,
modelDisplayName:
modelDisplayNames[idx] ||
selectedModels?.[idx]?.displayName ||
null,
stopReason: stopReason,
packets: packetsPerModel[idx] || [],
packetCount: packetsPerModel[idx]?.length || 0,
})
);
messagesToUpsertInLoop = [
updatedUserNode,
...updatedAssistantNodes,
];
} else {
// Single model mode (existing logic)
messagesToUpsertInLoop = [
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: [
{
...initialUserNode,
messageId: newUserMessageId ?? undefined,
@@ -1048,11 +894,8 @@ export default function useChatController({
: undefined;
})(),
},
];
}
currentMessageTreeLocal = upsertToCompleteMessageTree({
messages: messagesToUpsertInLoop,
],
// Pass the latest map state
completeMessageTreeOverride: currentMessageTreeLocal,
chatSessionId: frozenSessionId!,
});

View File

@@ -61,6 +61,11 @@ interface UseChatSessionControllerProps {
}) => Promise<void>;
}
export type SessionFetchError = {
type: "not_found" | "access_denied" | "unknown";
detail: string;
} | null;
export default function useChatSessionController({
existingChatSessionId,
searchParams,
@@ -80,6 +85,8 @@ export default function useChatSessionController({
const [currentSessionFileTokenCount, setCurrentSessionFileTokenCount] =
useState<number>(0);
const [projectFiles, setProjectFiles] = useState<ProjectFile[]>([]);
const [sessionFetchError, setSessionFetchError] =
useState<SessionFetchError>(null);
// Store actions
const updateSessionAndMessageTree = useChatSessionStore(
(state) => state.updateSessionAndMessageTree
@@ -151,6 +158,8 @@ export default function useChatSessionController({
}
async function initialSessionFetch() {
setSessionFetchError(null);
if (existingChatSessionId === null) {
// Clear the current session in the store to show intro messages
setCurrentSession(null);
@@ -178,9 +187,42 @@ export default function useChatSessionController({
setCurrentSession(existingChatSessionId);
setIsFetchingChatMessages(existingChatSessionId, true);
const response = await fetch(
`/api/chat/get-chat-session/${existingChatSessionId}`
);
let response: Response;
try {
response = await fetch(
`/api/chat/get-chat-session/${existingChatSessionId}`
);
} catch (error) {
setIsFetchingChatMessages(existingChatSessionId, false);
console.error("Failed to fetch chat session", {
chatSessionId: existingChatSessionId,
error,
});
setSessionFetchError({
type: "unknown",
detail: "Failed to load chat session. Please check your connection.",
});
return;
}
if (!response.ok) {
setIsFetchingChatMessages(existingChatSessionId, false);
let detail = "An unexpected error occurred.";
try {
const errorBody = await response.json();
detail = errorBody.detail || detail;
} catch {
// ignore parse errors
}
const type =
response.status === 404
? "not_found"
: response.status === 403
? "access_denied"
: "unknown";
setSessionFetchError({ type, detail });
return;
}
const session = await response.json();
const chatSession = session as BackendChatSession;
@@ -356,5 +398,6 @@ export default function useChatSessionController({
currentSessionFileTokenCount,
onMessageSelection,
projectFiles,
sessionFetchError,
};
}

View File

@@ -36,7 +36,11 @@ export function useMemoryManager({
setLocalMemories((prev) => {
const emptyNewItems = prev.filter((m) => m.isNew && !m.content.trim());
return [...emptyNewItems, ...existingMemories];
const availableSlots = MAX_MEMORY_COUNT - existingMemories.length;
return [
...emptyNewItems.slice(0, Math.max(0, availableSlots)),
...existingMemories,
];
});
initialMemoriesRef.current = memories;
}, [memories]);

View File

@@ -1,232 +0,0 @@
import { renderHook, act } from "@testing-library/react";
import useMultiModelChat from "@/hooks/useMultiModelChat";
import { LlmManager } from "@/lib/hooks";
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
// Mock buildLlmOptions — hook uses it internally for initialization.
// Tests here focus on CRUD operations, not the initialization side-effect.
jest.mock("@/refresh-components/popovers/LLMPopover", () => ({
buildLlmOptions: jest.fn(() => []),
}));
const makeLlmManager = (): LlmManager =>
({
llmProviders: [],
currentLlm: { modelName: null, provider: null },
isLoadingProviders: false,
}) as unknown as LlmManager;
const makeModel = (provider: string, modelName: string): SelectedModel => ({
name: provider,
provider,
modelName,
displayName: `${provider}/${modelName}`,
});
const GPT4 = makeModel("openai", "gpt-4");
const CLAUDE = makeModel("anthropic", "claude-opus-4-6");
const GEMINI = makeModel("google", "gemini-pro");
const GPT4_TURBO = makeModel("openai", "gpt-4-turbo");
// ---------------------------------------------------------------------------
// addModel
// ---------------------------------------------------------------------------
describe("addModel", () => {
it("adds a model to an empty selection", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
});
expect(result.current.selectedModels).toHaveLength(1);
expect(result.current.selectedModels[0]).toEqual(GPT4);
});
it("does not add a duplicate model", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(GPT4); // duplicate
});
expect(result.current.selectedModels).toHaveLength(1);
});
it("enforces MAX_MODELS (3) cap", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
result.current.addModel(GEMINI);
result.current.addModel(GPT4_TURBO); // should be ignored
});
expect(result.current.selectedModels).toHaveLength(3);
});
});
// ---------------------------------------------------------------------------
// removeModel
// ---------------------------------------------------------------------------
describe("removeModel", () => {
it("removes a model by index", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.removeModel(0); // remove GPT4
});
expect(result.current.selectedModels).toHaveLength(1);
expect(result.current.selectedModels[0]).toEqual(CLAUDE);
});
it("handles out-of-range index gracefully", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
});
act(() => {
result.current.removeModel(99); // no-op
});
expect(result.current.selectedModels).toHaveLength(1);
});
});
// ---------------------------------------------------------------------------
// replaceModel
// ---------------------------------------------------------------------------
describe("replaceModel", () => {
it("replaces the model at the given index", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.replaceModel(0, GEMINI);
});
expect(result.current.selectedModels[0]).toEqual(GEMINI);
expect(result.current.selectedModels[1]).toEqual(CLAUDE);
});
it("does not replace with a model already selected at another index", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.replaceModel(0, CLAUDE); // CLAUDE is already at index 1
});
// Should be a no-op — GPT4 stays at index 0
expect(result.current.selectedModels[0]).toEqual(GPT4);
});
});
// ---------------------------------------------------------------------------
// isMultiModelActive
// ---------------------------------------------------------------------------
describe("isMultiModelActive", () => {
it("is false with zero models", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
expect(result.current.isMultiModelActive).toBe(false);
});
it("is false with exactly one model", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
});
expect(result.current.isMultiModelActive).toBe(false);
});
it("is true with two or more models", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
expect(result.current.isMultiModelActive).toBe(true);
});
});
// ---------------------------------------------------------------------------
// buildLlmOverrides
// ---------------------------------------------------------------------------
describe("buildLlmOverrides", () => {
it("returns empty array when no models selected", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
expect(result.current.buildLlmOverrides()).toEqual([]);
});
it("maps selectedModels to LLMOverride format", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
const overrides = result.current.buildLlmOverrides();
expect(overrides).toHaveLength(2);
expect(overrides[0]).toEqual({
model_provider: "openai",
model_version: "gpt-4",
display_name: "openai/gpt-4",
});
expect(overrides[1]).toEqual({
model_provider: "anthropic",
model_version: "claude-opus-4-6",
display_name: "anthropic/claude-opus-4-6",
});
});
});
// ---------------------------------------------------------------------------
// clearModels
// ---------------------------------------------------------------------------
describe("clearModels", () => {
it("empties the selection", () => {
const { result } = renderHook(() => useMultiModelChat(makeLlmManager()));
act(() => {
result.current.addModel(GPT4);
result.current.addModel(CLAUDE);
});
act(() => {
result.current.clearModels();
});
expect(result.current.selectedModels).toHaveLength(0);
expect(result.current.isMultiModelActive).toBe(false);
});
});

View File

@@ -1,191 +0,0 @@
"use client";
import { useState, useCallback, useEffect, useMemo } from "react";
import { SelectedModel } from "@/refresh-components/popovers/ModelSelector";
import { LLMOverride } from "@/app/app/services/lib";
import { LlmManager } from "@/lib/hooks";
import { buildLlmOptions } from "@/refresh-components/popovers/LLMPopover";
const MAX_MODELS = 3;
export interface UseMultiModelChatReturn {
/** Currently selected models for multi-model comparison. */
selectedModels: SelectedModel[];
/** Whether multi-model mode is active (>1 model selected). */
isMultiModelActive: boolean;
/** Add a model to the selection. */
addModel: (model: SelectedModel) => void;
/** Remove a model by index. */
removeModel: (index: number) => void;
/** Replace a model at a specific index with a new one. */
replaceModel: (index: number, model: SelectedModel) => void;
/** Clear all selected models. */
clearModels: () => void;
/** Build the LLMOverride[] array from selectedModels. */
buildLlmOverrides: () => LLMOverride[];
/**
* Restore multi-model selection from model version strings (e.g. from chat history).
* Matches against available llmOptions to reconstruct full SelectedModel objects.
*/
restoreFromModelNames: (modelNames: string[]) => void;
/**
* Switch to a single model by name (after user picks a preferred response).
* Matches against llmOptions to find the full SelectedModel.
*/
selectSingleModel: (modelName: string) => void;
}
export default function useMultiModelChat(
llmManager: LlmManager
): UseMultiModelChatReturn {
const [selectedModels, setSelectedModels] = useState<SelectedModel[]>([]);
const [defaultInitialized, setDefaultInitialized] = useState(false);
// Initialize with the default model from llmManager once providers load
const llmOptions = useMemo(
() =>
llmManager.llmProviders ? buildLlmOptions(llmManager.llmProviders) : [],
[llmManager.llmProviders]
);
useEffect(() => {
if (defaultInitialized) return;
if (llmOptions.length === 0) return;
const { currentLlm } = llmManager;
// Don't initialize if currentLlm hasn't loaded yet
if (!currentLlm.modelName) return;
const match = llmOptions.find(
(opt) =>
opt.provider === currentLlm.provider &&
opt.modelName === currentLlm.modelName
);
if (match) {
setSelectedModels([
{
name: match.name,
provider: match.provider,
modelName: match.modelName,
displayName: match.displayName,
},
]);
setDefaultInitialized(true);
}
}, [llmOptions, llmManager.currentLlm, defaultInitialized, llmManager]);
const isMultiModelActive = selectedModels.length > 1;
const addModel = useCallback((model: SelectedModel) => {
setSelectedModels((prev) => {
if (prev.length >= MAX_MODELS) return prev;
if (
prev.some(
(m) =>
m.provider === model.provider && m.modelName === model.modelName
)
) {
return prev;
}
return [...prev, model];
});
}, []);
const removeModel = useCallback((index: number) => {
setSelectedModels((prev) => prev.filter((_, i) => i !== index));
}, []);
const replaceModel = useCallback((index: number, model: SelectedModel) => {
setSelectedModels((prev) => {
// Don't replace with a model that's already selected elsewhere
if (
prev.some(
(m, i) =>
i !== index &&
m.provider === model.provider &&
m.modelName === model.modelName
)
) {
return prev;
}
const next = [...prev];
next[index] = model;
return next;
});
}, []);
const clearModels = useCallback(() => {
setSelectedModels([]);
}, []);
const restoreFromModelNames = useCallback(
(modelNames: string[]) => {
if (modelNames.length < 2 || llmOptions.length === 0) return;
const restored: SelectedModel[] = [];
for (const name of modelNames) {
// Try matching by modelName (raw version string like "claude-opus-4-6")
// or by displayName (friendly name like "Claude Opus 4.6")
const match = llmOptions.find(
(opt) =>
opt.modelName === name ||
opt.displayName === name ||
opt.name === name
);
if (match) {
restored.push({
name: match.name,
provider: match.provider,
modelName: match.modelName,
displayName: match.displayName,
});
}
}
if (restored.length >= 2) {
setSelectedModels(restored);
setDefaultInitialized(true);
}
},
[llmOptions]
);
const selectSingleModel = useCallback(
(modelName: string) => {
if (llmOptions.length === 0) return;
const match = llmOptions.find(
(opt) =>
opt.modelName === modelName ||
opt.displayName === modelName ||
opt.name === modelName
);
if (match) {
setSelectedModels([
{
name: match.name,
provider: match.provider,
modelName: match.modelName,
displayName: match.displayName,
},
]);
}
},
[llmOptions]
);
const buildLlmOverrides = useCallback((): LLMOverride[] => {
return selectedModels.map((m) => ({
model_provider: m.name,
model_version: m.modelName,
display_name: m.displayName,
}));
}, [selectedModels]);
return {
selectedModels,
isMultiModelActive,
addModel,
removeModel,
replaceModel,
clearModels,
buildLlmOverrides,
restoreFromModelNames,
selectSingleModel,
};
}

View File

@@ -123,6 +123,9 @@ export interface LLMProviderFormProps {
open?: boolean;
onOpenChange?: (open: boolean) => void;
/** The current default model name for this provider (from the global default). */
defaultModelName?: string;
// Onboarding-specific (only when variant === "onboarding")
onboardingState?: OnboardingState;
onboardingActions?: OnboardingActions;

View File

@@ -0,0 +1,89 @@
import { getPastedFilesIfNoText } from "./clipboard";
type MockClipboardData = Parameters<typeof getPastedFilesIfNoText>[0];
function makeClipboardData({
textPlain = "",
text = "",
files = [],
}: {
textPlain?: string;
text?: string;
files?: File[];
}): MockClipboardData {
return {
items: files.map((file) => ({
kind: "file",
getAsFile: () => file,
})),
getData: (format: string) => {
if (format === "text/plain") {
return textPlain;
}
if (format === "text") {
return text;
}
return "";
},
};
}
describe("getPastedFilesIfNoText", () => {
it("prefers plain text over pasted files when both are present", () => {
const imageFile = new File(["slide preview"], "slide.png", {
type: "image/png",
});
expect(
getPastedFilesIfNoText(
makeClipboardData({
textPlain: "Welcome to PowerPoint for Mac",
files: [imageFile],
})
)
).toEqual([]);
});
it("falls back to text data when text/plain is empty", () => {
const imageFile = new File(["slide preview"], "slide.png", {
type: "image/png",
});
expect(
getPastedFilesIfNoText(
makeClipboardData({
text: "Welcome to PowerPoint for Mac",
files: [imageFile],
})
)
).toEqual([]);
});
it("still returns files for image-only pastes", () => {
const imageFile = new File(["slide preview"], "slide.png", {
type: "image/png",
});
expect(
getPastedFilesIfNoText(makeClipboardData({ files: [imageFile] }))
).toEqual([imageFile]);
});
it("ignores whitespace-only text and keeps file pastes working", () => {
const imageFile = new File(["slide preview"], "slide.png", {
type: "image/png",
});
expect(
getPastedFilesIfNoText(
makeClipboardData({
textPlain: " ",
text: "\n",
files: [imageFile],
})
)
).toEqual([imageFile]);
});
});

52
web/src/lib/clipboard.ts Normal file
View File

@@ -0,0 +1,52 @@
type ClipboardFileItem = {
kind: string;
getAsFile: () => File | null;
};
type ClipboardDataLike = {
items?: ArrayLike<ClipboardFileItem> | null;
getData: (format: string) => string;
};
function getClipboardText(
clipboardData: ClipboardDataLike,
format: "text/plain" | "text"
): string {
try {
return clipboardData.getData(format);
} catch {
return "";
}
}
export function getPastedFilesIfNoText(
clipboardData?: ClipboardDataLike | null
): File[] {
if (!clipboardData) {
return [];
}
const plainText = getClipboardText(clipboardData, "text/plain").trim();
const fallbackText = getClipboardText(clipboardData, "text").trim();
// Apps like PowerPoint on macOS can place both rendered image data and the
// original text on the clipboard. Prefer letting the textarea consume text.
if (plainText || fallbackText || !clipboardData.items) {
return [];
}
const pastedFiles: File[] = [];
for (let i = 0; i < clipboardData.items.length; i++) {
const item = clipboardData.items[i];
if (item?.kind !== "file") {
continue;
}
const file = item.getAsFile();
if (file) {
pastedFiles.push(file);
}
}
return pastedFiles;
}

View File

@@ -127,8 +127,7 @@ export const DESKTOP_SMALL_BREAKPOINT_PX = 912;
export const DESKTOP_MEDIUM_BREAKPOINT_PX = 1232;
export const DEFAULT_AVATAR_SIZE_PX = 18;
export const HORIZON_DISTANCE_PX = 800;
export const LOGO_FOLDED_SIZE_PX = 24;
export const LOGO_UNFOLDED_SIZE_PX = 88;
export const DEFAULT_LOGO_SIZE_PX = 24;
export const DEFAULT_CONTEXT_TOKENS = 120_000;
export const MAX_CHUNKS_FED_TO_CHAT = 25;

View File

@@ -19,6 +19,29 @@ const DEFAULT_AUTH_ERROR_MSG =
const DEFAULT_ERROR_MSG = "An error occurred while fetching the data.";
/**
* SWR `onErrorRetry` callback that suppresses automatic retries for
* authentication errors (401/403). Pass this to any SWR hook whose endpoint
* requires auth so that unauthenticated pages don't spam the backend.
*/
export const skipRetryOnAuthError: NonNullable<
import("swr").SWRConfiguration["onErrorRetry"]
> = (error, _key, _config, revalidate, { retryCount }) => {
if (
error instanceof FetchError &&
(error.status === 401 || error.status === 403)
)
return;
// For non-auth errors, retry with exponential backoff
if (
_config.errorRetryCount !== undefined &&
retryCount >= _config.errorRetryCount
)
return;
const delay = Math.min(2000 * 2 ** retryCount, 30000);
setTimeout(() => revalidate({ retryCount }), delay);
};
export const errorHandlingFetcher = async <T>(url: string): Promise<T> => {
const res = await fetch(url);

View File

@@ -0,0 +1,16 @@
"use client";
import { SWRConfig } from "swr";
import { skipRetryOnAuthError } from "@/lib/fetcher";
export default function SWRConfigProvider({
children,
}: {
children: React.ReactNode;
}) {
return (
<SWRConfig value={{ onErrorRetry: skipRetryOnAuthError }}>
{children}
</SWRConfig>
);
}

Some files were not shown because too many files have changed in this diff Show More