mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-14 19:32:53 +00:00
Compare commits
1 Commits
v3.2.3
...
jamison/fe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c709a4fb6 |
@@ -9,6 +9,7 @@ repos:
|
||||
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
hooks:
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
@@ -17,7 +18,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"backend",
|
||||
"-o",
|
||||
"backend/requirements/default.txt",
|
||||
@@ -30,7 +31,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"dev",
|
||||
"-o",
|
||||
"backend/requirements/dev.txt",
|
||||
@@ -43,7 +44,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"ee",
|
||||
"-o",
|
||||
"backend/requirements/ee.txt",
|
||||
@@ -56,7 +57,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"model_server",
|
||||
"-o",
|
||||
"backend/requirements/model_server.txt",
|
||||
|
||||
3
.vscode/launch.json
vendored
3
.vscode/launch.json
vendored
@@ -531,7 +531,8 @@
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"sync"
|
||||
"sync",
|
||||
"--all-extras"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
|
||||
@@ -117,7 +117,7 @@ If using PowerShell, the command slightly differs:
|
||||
Install the required Python dependencies:
|
||||
|
||||
```bash
|
||||
uv sync
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector):
|
||||
|
||||
@@ -13,7 +13,6 @@ from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -108,13 +107,12 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
Get current seat usage directly from database.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users.
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role
|
||||
and the anonymous system user).
|
||||
|
||||
Only human accounts count toward seat limits.
|
||||
SERVICE_ACCOUNT (API key dummy users), EXT_PERM_USER, and the
|
||||
anonymous system user are excluded. BOT (Slack users) ARE counted
|
||||
because they represent real humans and get upgraded to STANDARD
|
||||
when they log in via web.
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
@@ -131,7 +129,6 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
|
||||
User.account_type != AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -11,8 +11,6 @@ require a valid SCIM bearer token.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -24,7 +22,6 @@ from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -68,25 +65,12 @@ from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
# Namespace prefix for the seat-allocation advisory lock. Hashed together
|
||||
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
|
||||
# never block each other) and cannot collide with unrelated advisory locks.
|
||||
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
|
||||
|
||||
|
||||
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
|
||||
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
|
||||
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
|
||||
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
|
||||
return struct.unpack("q", digest[:8])[0]
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -225,37 +209,12 @@ def _apply_exclusions(
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None.
|
||||
|
||||
Acquires a transaction-scoped advisory lock so that concurrent
|
||||
SCIM requests are serialized. IdPs like Okta send provisioning
|
||||
requests in parallel batches — without serialization the check is
|
||||
vulnerable to a TOCTOU race where N concurrent requests each see
|
||||
"seats available", all insert, and the tenant ends up over its
|
||||
seat limit.
|
||||
|
||||
The lock is held until the caller's next COMMIT or ROLLBACK, which
|
||||
means the seat count cannot change between the check here and the
|
||||
subsequent INSERT/UPDATE. Each call site in this module follows
|
||||
the pattern: _check_seat_availability → write → dal.commit()
|
||||
(which releases the lock for the next waiting request).
|
||||
"""
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
|
||||
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
|
||||
# The lock id is derived from the tenant so unrelated tenants never block
|
||||
# each other, and from a namespace string so it cannot collide with
|
||||
# unrelated advisory locks elsewhere in the codebase.
|
||||
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
|
||||
dal.session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(:lock_id)"),
|
||||
{"lock_id": lock_id},
|
||||
)
|
||||
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
@@ -54,21 +53,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _load_google_json(raw: object) -> dict[str, Any]:
|
||||
"""Accept both the current (dict) and legacy (JSON string) KV payload shapes.
|
||||
|
||||
Payloads written before the fix for serializing Google credentials into
|
||||
``EncryptedJson`` columns are stored as JSON strings; new writes store dicts.
|
||||
Once every install has re-uploaded their Google credentials the legacy
|
||||
``str`` branch can be removed.
|
||||
"""
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return json.loads(raw)
|
||||
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
@@ -178,13 +162,12 @@ def build_service_account_creds(
|
||||
|
||||
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
credential_json = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
)
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
credential_json = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
@@ -205,12 +188,12 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
|
||||
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds = _load_google_json(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleAppCredentials(**creds)
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_cred(
|
||||
@@ -218,14 +201,10 @@ def upsert_google_app_cred(
|
||||
) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY,
|
||||
app_credentials.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_CRED_KEY, app_credentials.model_dump(mode="json"), encrypt=True
|
||||
)
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -241,14 +220,12 @@ def delete_google_app_cred(source: DocumentSource) -> None:
|
||||
|
||||
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleServiceAccountKey(**creds)
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_service_account_key(
|
||||
@@ -257,14 +234,12 @@ def upsert_service_account_key(
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.model_dump(mode="json"),
|
||||
service_account_key.json(),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -60,10 +60,8 @@ logger = setup_logger()
|
||||
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
_JIRA_BULK_FETCH_LIMIT = 100
|
||||
|
||||
# Constants for Jira field names
|
||||
_FIELD_REPORTER = "reporter"
|
||||
@@ -257,13 +255,15 @@ def _bulk_fetch_request(
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def _bulk_fetch_batch(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch a single batch (must be <= _JIRA_BULK_FETCH_LIMIT).
|
||||
On JSONDecodeError, recursively bisects until it succeeds or reaches size 1."""
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
return _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
@@ -277,25 +277,12 @@ def _bulk_fetch_batch(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = _bulk_fetch_batch(jira_client, issue_ids[:mid], fields)
|
||||
right = _bulk_fetch_batch(jira_client, issue_ids[mid:], fields)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
raw_issues: list[dict[str, Any]] = []
|
||||
for batch in chunked(issue_ids, _JIRA_BULK_FETCH_LIMIT):
|
||||
try:
|
||||
raw_issues.extend(_bulk_fetch_batch(jira_client, list(batch), fields))
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -7,14 +6,6 @@ from pydantic import BaseModel
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DirectThreadFetch:
|
||||
"""Request to fetch a Slack thread directly by channel and timestamp."""
|
||||
|
||||
channel_id: str
|
||||
thread_ts: str
|
||||
|
||||
|
||||
class ChannelMetadata(TypedDict):
|
||||
"""Type definition for cached channel metadata."""
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.models import SlackMessage
|
||||
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
|
||||
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
|
||||
@@ -50,6 +49,7 @@ from onyx.server.federated.models import FederatedConnectorDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -58,6 +58,7 @@ HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
|
||||
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
|
||||
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
|
||||
|
||||
@@ -420,94 +421,6 @@ class SlackQueryResult(BaseModel):
|
||||
filtered_channels: list[str] # Channels filtered out during this query
|
||||
|
||||
|
||||
def _fetch_thread_from_url(
|
||||
thread_fetch: DirectThreadFetch,
|
||||
access_token: str,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
"""Fetch a thread directly from a Slack URL via conversations.replies."""
|
||||
channel_id = thread_fetch.channel_id
|
||||
thread_ts = thread_fetch.thread_ts
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
response = slack_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
)
|
||||
response.validate()
|
||||
messages: list[dict[str, Any]] = response.get("messages", [])
|
||||
except SlackApiError as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
if not messages:
|
||||
logger.warning(
|
||||
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
# Build thread text from all messages
|
||||
thread_text = _build_thread_text(messages, access_token, None, slack_client)
|
||||
|
||||
# Get channel name from metadata cache or API
|
||||
channel_name = "unknown"
|
||||
if channel_metadata_dict and channel_id in channel_metadata_dict:
|
||||
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
|
||||
else:
|
||||
try:
|
||||
ch_response = slack_client.conversations_info(channel=channel_id)
|
||||
ch_response.validate()
|
||||
channel_info: dict[str, Any] = ch_response.get("channel", {})
|
||||
channel_name = channel_info.get("name", "unknown")
|
||||
except SlackApiError:
|
||||
pass
|
||||
|
||||
# Build the SlackMessage
|
||||
parent_msg = messages[0]
|
||||
message_ts = parent_msg.get("ts", thread_ts)
|
||||
username = parent_msg.get("user", "unknown_user")
|
||||
parent_text = parent_msg.get("text", "")
|
||||
snippet = (
|
||||
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
|
||||
).replace("\n", " ")
|
||||
|
||||
doc_time = datetime.fromtimestamp(float(message_ts))
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
|
||||
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
|
||||
|
||||
permalink = (
|
||||
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
|
||||
)
|
||||
|
||||
slack_message = SlackMessage(
|
||||
document_id=f"{channel_id}_{message_ts}",
|
||||
channel_id=channel_id,
|
||||
message_id=message_ts,
|
||||
thread_id=None, # Prevent double-enrichment in thread context fetch
|
||||
link=permalink,
|
||||
metadata={
|
||||
"channel": channel_name,
|
||||
"time": doc_time.isoformat(),
|
||||
},
|
||||
timestamp=doc_time,
|
||||
recency_bias=recency_bias,
|
||||
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
|
||||
text=thread_text,
|
||||
highlighted_texts=set(),
|
||||
slack_score=100000.0, # High priority — user explicitly asked for this thread
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
|
||||
)
|
||||
|
||||
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
access_token: str,
|
||||
@@ -519,6 +432,7 @@ def query_slack(
|
||||
available_channels: list[str] | None = None,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
|
||||
# Check if query has channel override (user specified channels in query)
|
||||
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
|
||||
|
||||
@@ -748,6 +662,7 @@ def _fetch_thread_context(
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# If not a thread, return original text as success
|
||||
if thread_id is None:
|
||||
@@ -780,37 +695,62 @@ def _fetch_thread_context(
|
||||
if len(messages) <= 1:
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# Build thread text from thread starter + all replies
|
||||
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build thread text including all replies.
|
||||
|
||||
Includes the thread parent message followed by all replies in order.
|
||||
"""
|
||||
"""Build the thread text from messages."""
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# All messages after index 0 are replies
|
||||
replies = messages[1:]
|
||||
if not replies:
|
||||
return thread_text
|
||||
|
||||
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
else:
|
||||
message_id_idx = next(
|
||||
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
|
||||
)
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
for msg in replies:
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add following replies
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
msg_text = msg.get("text", "")
|
||||
msg_sender = msg.get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
|
||||
# Replace user IDs with names using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
@@ -1036,16 +976,7 @@ def slack_retrieval(
|
||||
|
||||
# Query slack with entity filtering
|
||||
llm = get_default_llm()
|
||||
query_items = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Partition into direct thread fetches and search query strings
|
||||
direct_fetches: list[DirectThreadFetch] = []
|
||||
query_strings: list[str] = []
|
||||
for item in query_items:
|
||||
if isinstance(item, DirectThreadFetch):
|
||||
direct_fetches.append(item)
|
||||
else:
|
||||
query_strings.append(item)
|
||||
query_strings = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Determine filtering based on entities OR context (bot)
|
||||
include_dm = False
|
||||
@@ -1062,16 +993,8 @@ def slack_retrieval(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
# Build search tasks — direct thread fetches + keyword searches
|
||||
search_tasks: list[tuple] = [
|
||||
(
|
||||
_fetch_thread_from_url,
|
||||
(fetch, access_token, channel_metadata_dict),
|
||||
)
|
||||
for fetch in direct_fetches
|
||||
]
|
||||
|
||||
search_tasks.extend(
|
||||
# Build search tasks
|
||||
search_tasks = [
|
||||
(
|
||||
query_slack,
|
||||
(
|
||||
@@ -1087,7 +1010,7 @@ def slack_retrieval(
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
)
|
||||
]
|
||||
|
||||
# If include_dm is True AND we're not already searching all channels,
|
||||
# add additional searches without channel filters.
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import ValidationError
|
||||
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -639,38 +638,12 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
return [query_text]
|
||||
|
||||
|
||||
SLACK_URL_PATTERN = re.compile(
|
||||
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
|
||||
)
|
||||
|
||||
|
||||
def extract_slack_message_urls(
|
||||
query_text: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Extract Slack message URLs from query text.
|
||||
|
||||
Parses URLs like:
|
||||
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
|
||||
|
||||
Returns list of (channel_id, thread_ts) tuples.
|
||||
The 16-digit timestamp is converted to Slack ts format (with dot).
|
||||
"""
|
||||
results = []
|
||||
for match in SLACK_URL_PATTERN.finditer(query_text):
|
||||
channel_id = match.group(1)
|
||||
raw_ts = match.group(2)
|
||||
# Convert p1775491616524769 -> 1775491616.524769
|
||||
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
|
||||
results.append((channel_id, thread_ts))
|
||||
return results
|
||||
|
||||
|
||||
def build_slack_queries(
|
||||
query: ChunkIndexRequest,
|
||||
llm: LLM,
|
||||
entities: dict[str, Any] | None = None,
|
||||
available_channels: list[str] | None = None,
|
||||
) -> list[str | DirectThreadFetch]:
|
||||
) -> list[str]:
|
||||
"""Build Slack query strings with date filtering and query expansion."""
|
||||
default_search_days = 30
|
||||
if entities:
|
||||
@@ -695,15 +668,6 @@ def build_slack_queries(
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
|
||||
|
||||
# Check for Slack message URLs — if found, add direct fetch requests
|
||||
url_fetches: list[DirectThreadFetch] = []
|
||||
slack_urls = extract_slack_message_urls(query.query)
|
||||
for channel_id, thread_ts in slack_urls:
|
||||
url_fetches.append(
|
||||
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
|
||||
)
|
||||
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
|
||||
|
||||
# ALWAYS extract channel references from the query (not just for recency queries)
|
||||
channel_references = extract_channel_references_from_query(query.query)
|
||||
|
||||
@@ -720,9 +684,7 @@ def build_slack_queries(
|
||||
|
||||
# If valid channels detected, use ONLY those channels with NO keywords
|
||||
# Return query with ONLY time filter + channel filter (no keywords)
|
||||
return url_fetches + [
|
||||
build_channel_override_query(channel_references, time_filter)
|
||||
]
|
||||
return [build_channel_override_query(channel_references, time_filter)]
|
||||
except ValueError as e:
|
||||
# If validation fails, log the error and continue with normal flow
|
||||
logger.warning(f"Channel reference validation failed: {e}")
|
||||
@@ -740,8 +702,7 @@ def build_slack_queries(
|
||||
rephrased_queries = expand_query_with_llm(query.query, llm)
|
||||
|
||||
# Build final query strings with time filters
|
||||
search_queries = [
|
||||
return [
|
||||
rephrased_query.strip() + time_filter
|
||||
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
|
||||
]
|
||||
return url_fetches + search_queries
|
||||
|
||||
@@ -96,32 +96,6 @@ def _truncate_description(description: str | None, max_length: int = 500) -> str
|
||||
return description[: max_length - 3] + "..."
|
||||
|
||||
|
||||
# TODO: Replace mask-comparison approach with an explicit Unset sentinel from the
|
||||
# frontend indicating whether each credential field was actually modified. The current
|
||||
# approach is brittle (e.g. short credentials produce a fixed-length mask that could
|
||||
# collide) and mutates request values, which is surprising. The frontend should signal
|
||||
# "unchanged" vs "new value" directly rather than relying on masked-string equality.
|
||||
def _restore_masked_oauth_credentials(
|
||||
request_client_id: str | None,
|
||||
request_client_secret: str | None,
|
||||
existing_client: OAuthClientInformationFull,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""If the frontend sent back masked credentials, restore the real stored values."""
|
||||
if (
|
||||
request_client_id
|
||||
and existing_client.client_id
|
||||
and request_client_id == mask_string(existing_client.client_id)
|
||||
):
|
||||
request_client_id = existing_client.client_id
|
||||
if (
|
||||
request_client_secret
|
||||
and existing_client.client_secret
|
||||
and request_client_secret == mask_string(existing_client.client_secret)
|
||||
):
|
||||
request_client_secret = existing_client.client_secret
|
||||
return request_client_id, request_client_secret
|
||||
|
||||
|
||||
router = APIRouter(prefix="/mcp")
|
||||
admin_router = APIRouter(prefix="/admin/mcp")
|
||||
STATE_TTL_SECONDS = 60 * 5 # 5 minutes
|
||||
@@ -418,26 +392,6 @@ async def _connect_oauth(
|
||||
detail=f"Server was configured with authentication type {auth_type_str}",
|
||||
)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so we don't overwrite them with masks.
|
||||
if mcp_server.admin_connection_config:
|
||||
existing_data = extract_connection_data(
|
||||
mcp_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
existing_client_raw = existing_data.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
if existing_client_raw:
|
||||
existing_client = OAuthClientInformationFull.model_validate(
|
||||
existing_client_raw
|
||||
)
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
existing_client,
|
||||
)
|
||||
|
||||
# Create admin config with client info if provided
|
||||
config_data = MCPConnectionData(headers={})
|
||||
if request.oauth_client_id and request.oauth_client_secret:
|
||||
@@ -1402,19 +1356,6 @@ def _upsert_mcp_server(
|
||||
if client_info_raw:
|
||||
client_info = OAuthClientInformationFull.model_validate(client_info_raw)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so the comparison below sees no change
|
||||
# and the credentials aren't overwritten with masked strings.
|
||||
if client_info and request.auth_type == MCPAuthenticationType.OAUTH:
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
client_info,
|
||||
)
|
||||
|
||||
changing_connection_config = (
|
||||
not mcp_server.admin_connection_config
|
||||
or (
|
||||
|
||||
@@ -47,6 +47,8 @@ from onyx.llm.factory import get_llm
|
||||
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.utils import get_bedrock_token_limit
|
||||
from onyx.llm.utils import get_llm_contextual_cost
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.llm.utils import litellm_thinks_model_supports_image_input
|
||||
from onyx.llm.utils import test_llm
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
@@ -62,6 +64,8 @@ from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import BifrostFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BifrostModelsRequest
|
||||
from onyx.server.manage.llm.models import CustomProviderModelResponse
|
||||
from onyx.server.manage.llm.models import CustomProviderModelsRequest
|
||||
from onyx.server.manage.llm.models import CustomProviderOption
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
@@ -111,43 +115,6 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _resolve_api_key(
|
||||
api_key: str | None,
|
||||
provider_name: str | None,
|
||||
api_base: str | None,
|
||||
db_session: Session,
|
||||
) -> str | None:
|
||||
"""Return the real API key for model-fetch endpoints.
|
||||
|
||||
When editing an existing provider the form value is masked (e.g.
|
||||
``sk-a****b1c2``). If *provider_name* is supplied we can look up
|
||||
the unmasked key from the database so the external request succeeds.
|
||||
|
||||
The stored key is only returned when the request's *api_base*
|
||||
matches the value stored in the database.
|
||||
"""
|
||||
if not provider_name:
|
||||
return api_key
|
||||
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.api_key:
|
||||
# Normalise both URLs before comparing so trailing-slash
|
||||
# differences don't cause a false mismatch.
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
request_base = (api_base or "").strip().rstrip("/")
|
||||
if stored_base != request_base:
|
||||
return api_key
|
||||
|
||||
stored_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
# Only resolve when the incoming value is the masked form of the
|
||||
# stored key — i.e. the user hasn't typed a new key.
|
||||
if api_key and api_key == _mask_string(stored_key):
|
||||
return stored_key
|
||||
return api_key
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
@@ -313,6 +280,158 @@ def fetch_custom_provider_names(
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/custom/available-models")
|
||||
def fetch_custom_provider_models(
|
||||
request: CustomProviderModelsRequest,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> list[CustomProviderModelResponse]:
|
||||
"""Fetch models for a custom provider.
|
||||
|
||||
When ``api_base`` is provided the endpoint hits the provider's
|
||||
OpenAI-compatible ``/v1/models`` (or ``/{api_version}/models``) to
|
||||
discover live models. Otherwise it falls back to the static list
|
||||
that LiteLLM ships for the given provider slug.
|
||||
|
||||
In both cases the response is enriched with metadata from LiteLLM
|
||||
(display name, max input tokens, vision support) when available.
|
||||
"""
|
||||
if request.api_base:
|
||||
return _fetch_custom_models_from_api(
|
||||
provider=request.provider,
|
||||
api_base=request.api_base,
|
||||
api_key=request.api_key,
|
||||
api_version=request.api_version,
|
||||
)
|
||||
|
||||
return _fetch_custom_models_from_litellm(request.provider)
|
||||
|
||||
|
||||
def _enrich_custom_model(
|
||||
name: str,
|
||||
provider: str,
|
||||
*,
|
||||
api_display_name: str | None = None,
|
||||
api_max_input_tokens: int | None = None,
|
||||
api_supports_image_input: bool | None = None,
|
||||
) -> CustomProviderModelResponse:
|
||||
"""Build a ``CustomProviderModelResponse`` enriched with LiteLLM metadata.
|
||||
|
||||
Values explicitly provided by the source API take precedence; LiteLLM
|
||||
metadata is used as a fallback.
|
||||
"""
|
||||
from onyx.llm.model_name_parser import parse_litellm_model_name
|
||||
|
||||
# LiteLLM keys are typically "provider/model"
|
||||
litellm_key = f"{provider}/{name}" if not name.startswith(f"{provider}/") else name
|
||||
parsed = parse_litellm_model_name(litellm_key)
|
||||
|
||||
# display_name: prefer API-provided name, then LiteLLM enrichment, then raw name
|
||||
if api_display_name and api_display_name != name:
|
||||
display_name = api_display_name
|
||||
else:
|
||||
display_name = parsed.display_name or name
|
||||
|
||||
# max_input_tokens: prefer API value, then LiteLLM lookup
|
||||
if api_max_input_tokens is not None:
|
||||
max_input_tokens: int | None = api_max_input_tokens
|
||||
else:
|
||||
try:
|
||||
max_input_tokens = get_max_input_tokens(name, provider)
|
||||
except Exception:
|
||||
max_input_tokens = None
|
||||
|
||||
# supports_image_input: prefer API value, then LiteLLM inference
|
||||
if api_supports_image_input is not None:
|
||||
supports_image = api_supports_image_input
|
||||
else:
|
||||
supports_image = litellm_thinks_model_supports_image_input(name, provider)
|
||||
|
||||
return CustomProviderModelResponse(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
max_input_tokens=max_input_tokens,
|
||||
supports_image_input=supports_image,
|
||||
)
|
||||
|
||||
|
||||
def _fetch_custom_models_from_api(
|
||||
provider: str,
|
||||
api_base: str,
|
||||
api_key: str | None,
|
||||
api_version: str | None,
|
||||
) -> list[CustomProviderModelResponse]:
|
||||
"""Hit an OpenAI-compatible ``/v1/models`` (or versioned variant)."""
|
||||
cleaned = api_base.strip().rstrip("/")
|
||||
if api_version:
|
||||
url = f"{cleaned}/{api_version.strip().strip('/')}/models"
|
||||
elif cleaned.endswith("/v1"):
|
||||
url = f"{cleaned}/models"
|
||||
else:
|
||||
url = f"{cleaned}/v1/models"
|
||||
|
||||
response_json = _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="Custom provider",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from the provider's API.",
|
||||
)
|
||||
|
||||
results: list[CustomProviderModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_id = model.get("id", "")
|
||||
if not model_id:
|
||||
continue
|
||||
if is_embedding_model(model_id):
|
||||
continue
|
||||
results.append(
|
||||
_enrich_custom_model(
|
||||
model_id,
|
||||
provider,
|
||||
api_display_name=model.get("name"),
|
||||
api_max_input_tokens=model.get("context_length"),
|
||||
api_supports_image_input=infer_vision_support(model_id),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse custom provider model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from the provider's API.",
|
||||
)
|
||||
|
||||
return sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
|
||||
def _fetch_custom_models_from_litellm(
|
||||
provider: str,
|
||||
) -> list[CustomProviderModelResponse]:
|
||||
"""Fall back to litellm's static ``models_by_provider`` mapping."""
|
||||
import litellm
|
||||
|
||||
model_names = litellm.models_by_provider.get(provider)
|
||||
if model_names is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Unknown provider: {provider}",
|
||||
)
|
||||
return sorted(
|
||||
(_enrich_custom_model(name, provider) for name in model_names),
|
||||
key=lambda m: m.name.lower(),
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/built-in/options")
|
||||
def fetch_llm_options(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
@@ -1211,17 +1330,16 @@ def get_ollama_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str | None) -> dict:
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
"""Perform GET to OpenRouter /models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/models"
|
||||
headers: dict[str, str] = {
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
# Optional headers recommended by OpenRouter for attribution
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
@@ -1244,12 +1362,8 @@ def get_openrouter_available_models(
|
||||
Parses id, name (display), context_length, and architecture.input_modalities.
|
||||
"""
|
||||
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openrouter_models_response(
|
||||
api_base=request.api_base, api_key=api_key
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
data = response_json.get("data", [])
|
||||
@@ -1342,18 +1456,13 @@ def get_lm_studio_available_models(
|
||||
|
||||
# If provider_name is given and the api_key hasn't been changed by the user,
|
||||
# fall back to the stored API key from the database (the form value is masked).
|
||||
# Only do so when the api_base matches what is stored.
|
||||
api_key = request.api_key
|
||||
if request.provider_name and not request.api_key_changed:
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=request.provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.custom_config:
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
if stored_base == cleaned_api_base:
|
||||
api_key = existing_provider.custom_config.get(
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
)
|
||||
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
|
||||
|
||||
url = f"{cleaned_api_base}/api/v1/models"
|
||||
headers: dict[str, str] = {}
|
||||
@@ -1437,12 +1546,8 @@ def get_litellm_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=api_key, api_base=request.api_base
|
||||
api_key=request.api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1499,7 +1604,7 @@ def get_litellm_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str | None, api_base: str) -> dict:
|
||||
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
@@ -1574,12 +1679,8 @@ def get_bifrost_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BifrostFinalModelResponse]:
|
||||
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_bifrost_models_response(
|
||||
api_base=request.api_base, api_key=api_key
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1668,12 +1769,8 @@ def get_openai_compatible_server_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OpenAICompatibleFinalModelResponse]:
|
||||
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openai_compatible_server_response(
|
||||
api_base=request.api_base, api_key=api_key
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
|
||||
@@ -477,6 +477,21 @@ class BifrostFinalModelResponse(BaseModel):
|
||||
supports_reasoning: bool
|
||||
|
||||
|
||||
# Custom provider dynamic models fetch
|
||||
class CustomProviderModelsRequest(BaseModel):
|
||||
provider: str # LiteLLM provider slug (e.g. "deepseek", "fireworks_ai")
|
||||
api_base: str | None = None # If set, fetches live models via /v1/models
|
||||
api_key: str | None = None
|
||||
api_version: str | None = None # If set, used to construct the models URL
|
||||
|
||||
|
||||
class CustomProviderModelResponse(BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
# OpenAI Compatible dynamic models fetch
|
||||
class OpenAICompatibleModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
|
||||
@@ -65,7 +65,6 @@ class Settings(BaseModel):
|
||||
anonymous_user_enabled: bool | None = None
|
||||
invite_only_enabled: bool = False
|
||||
deep_research_enabled: bool | None = None
|
||||
multi_model_chat_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
|
||||
# Whether EE features are unlocked for use.
|
||||
@@ -90,8 +89,7 @@ class Settings(BaseModel):
|
||||
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
|
||||
)
|
||||
file_token_count_threshold_k: int | None = Field(
|
||||
default=None,
|
||||
ge=0, # thousands of tokens; None = context-aware default
|
||||
default=None, ge=0 # thousands of tokens; None = context-aware default
|
||||
)
|
||||
|
||||
# Connector settings
|
||||
|
||||
10
backend/pyproject.toml
Normal file
10
backend/pyproject.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[project]
|
||||
name = "onyx-backend"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"onyx[backend,dev,ee]",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
onyx = { workspace = true }
|
||||
@@ -46,11 +46,11 @@ curl -LsSf https://astral.py/uv/install.sh | sh
|
||||
|
||||
1. Edit `pyproject.toml`
|
||||
2. Add/update/remove dependencies in the appropriate section:
|
||||
- `[dependency-groups]` for dev tools
|
||||
- `[project.dependencies]` for **shared** dependencies (used by both backend and model_server)
|
||||
- `[dependency-groups.backend]` for backend-only dependencies
|
||||
- `[dependency-groups.dev]` for dev tools
|
||||
- `[dependency-groups.ee]` for EE features
|
||||
- `[dependency-groups.model_server]` for model_server-only dependencies (ML packages)
|
||||
- `[project.optional-dependencies.backend]` for backend-only dependencies
|
||||
- `[project.optional-dependencies.model_server]` for model_server-only dependencies (ML packages)
|
||||
- `[project.optional-dependencies.ee]` for EE features
|
||||
3. Commit your changes - pre-commit hooks will automatically regenerate the lock file and requirements
|
||||
|
||||
### 3. Generating Lock File and Requirements
|
||||
@@ -64,10 +64,10 @@ To manually regenerate:
|
||||
|
||||
```bash
|
||||
uv lock
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --group backend -o backend/requirements/default.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --extra backend -o backend/requirements/default.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --group dev -o backend/requirements/dev.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --group ee -o backend/requirements/ee.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --group model_server -o backend/requirements/model_server.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --extra ee -o backend/requirements/ee.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --extra model_server -o backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
### 4. Installing Dependencies
|
||||
@@ -76,14 +76,30 @@ If enabled, all packages are installed automatically by the `uv-sync` pre-commit
|
||||
branches or pulling new changes.
|
||||
|
||||
```bash
|
||||
# For development (most common) — installs shared + backend + dev + ee
|
||||
uv sync
|
||||
# For everything (most common)
|
||||
uv sync --all-extras
|
||||
|
||||
# For backend production only (shared + backend dependencies)
|
||||
uv sync --no-default-groups --group backend
|
||||
# For backend production (shared + backend dependencies)
|
||||
uv sync --extra backend
|
||||
|
||||
# For backend development (shared + backend + dev tools)
|
||||
uv sync --extra backend --extra dev
|
||||
|
||||
# For backend with EE (shared + backend + ee)
|
||||
uv sync --extra backend --extra ee
|
||||
|
||||
# For model server (shared + model_server, NO backend deps!)
|
||||
uv sync --no-default-groups --group model_server
|
||||
uv sync --extra model_server
|
||||
```
|
||||
|
||||
`uv` aggressively [ignores active virtual environments](https://docs.astral.sh/uv/concepts/projects/config/#project-environment-path) and prefers the root virtual environment.
|
||||
When working in workspace packages, be sure to pass `--active` when syncing the virtual environment:
|
||||
|
||||
```bash
|
||||
cd backend/
|
||||
source .venv/bin/activate
|
||||
uv sync --active
|
||||
uv run --active ...
|
||||
```
|
||||
|
||||
### 5. Upgrading Dependencies
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --group backend -o backend/requirements/default.txt
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --extra backend -o backend/requirements/default.txt
|
||||
agent-client-protocol==0.7.1
|
||||
# via onyx
|
||||
aioboto3==15.1.0
|
||||
@@ -19,6 +19,7 @@ aiohttp==3.13.4
|
||||
# aiobotocore
|
||||
# discord-py
|
||||
# litellm
|
||||
# onyx
|
||||
# voyageai
|
||||
aioitertools==0.13.0
|
||||
# via aiobotocore
|
||||
@@ -27,6 +28,7 @@ aiolimiter==1.2.1
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
alembic==1.10.4
|
||||
# via onyx
|
||||
amqp==5.3.1
|
||||
# via kombu
|
||||
annotated-doc==0.0.4
|
||||
@@ -49,10 +51,13 @@ argon2-cffi==23.1.0
|
||||
argon2-cffi-bindings==25.1.0
|
||||
# via argon2-cffi
|
||||
asana==5.0.8
|
||||
# via onyx
|
||||
async-timeout==5.0.1 ; python_full_version < '3.11.3'
|
||||
# via redis
|
||||
asyncpg==0.30.0
|
||||
# via onyx
|
||||
atlassian-python-api==3.41.16
|
||||
# via onyx
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
@@ -63,6 +68,7 @@ attrs==25.4.0
|
||||
authlib==1.6.9
|
||||
# via fastmcp
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
# via onyx
|
||||
babel==2.17.0
|
||||
# via courlan
|
||||
backoff==2.2.1
|
||||
@@ -80,6 +86,7 @@ beautifulsoup4==4.12.3
|
||||
# atlassian-python-api
|
||||
# markdownify
|
||||
# markitdown
|
||||
# onyx
|
||||
# unstructured
|
||||
billiard==4.2.3
|
||||
# via celery
|
||||
@@ -87,7 +94,9 @@ boto3==1.39.11
|
||||
# via
|
||||
# aiobotocore
|
||||
# cohere
|
||||
# onyx
|
||||
boto3-stubs==1.39.11
|
||||
# via onyx
|
||||
botocore==1.39.11
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -96,6 +105,7 @@ botocore==1.39.11
|
||||
botocore-stubs==1.40.74
|
||||
# via boto3-stubs
|
||||
braintrust==0.3.9
|
||||
# via onyx
|
||||
brotli==1.2.0
|
||||
# via onyx
|
||||
bytecode==0.17.0
|
||||
@@ -105,6 +115,7 @@ cachetools==6.2.2
|
||||
caio==0.9.25
|
||||
# via aiofile
|
||||
celery==5.5.1
|
||||
# via onyx
|
||||
certifi==2025.11.12
|
||||
# via
|
||||
# asana
|
||||
@@ -123,6 +134,7 @@ cffi==2.0.0
|
||||
# pynacl
|
||||
# zstandard
|
||||
chardet==5.2.0
|
||||
# via onyx
|
||||
charset-normalizer==3.4.4
|
||||
# via
|
||||
# htmldate
|
||||
@@ -134,6 +146,7 @@ charset-normalizer==3.4.4
|
||||
chevron==0.14.0
|
||||
# via braintrust
|
||||
chonkie==1.0.10
|
||||
# via onyx
|
||||
claude-agent-sdk==0.1.19
|
||||
# via onyx
|
||||
click==8.3.1
|
||||
@@ -188,12 +201,15 @@ cryptography==46.0.6
|
||||
cyclopts==4.2.4
|
||||
# via fastmcp
|
||||
dask==2026.1.1
|
||||
# via distributed
|
||||
# via
|
||||
# distributed
|
||||
# onyx
|
||||
dataclasses-json==0.6.7
|
||||
# via unstructured
|
||||
dateparser==1.2.2
|
||||
# via htmldate
|
||||
ddtrace==3.10.0
|
||||
# via onyx
|
||||
decorator==5.2.1
|
||||
# via retry
|
||||
defusedxml==0.7.1
|
||||
@@ -207,6 +223,7 @@ deprecated==1.3.1
|
||||
discord-py==2.4.0
|
||||
# via onyx
|
||||
distributed==2026.1.1
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via
|
||||
# openai
|
||||
@@ -218,6 +235,7 @@ docstring-parser==0.17.0
|
||||
docutils==0.22.3
|
||||
# via rich-rst
|
||||
dropbox==12.0.2
|
||||
# via onyx
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
email-validator==2.2.0
|
||||
@@ -233,6 +251,7 @@ et-xmlfile==2.0.0
|
||||
events==0.5
|
||||
# via opensearch-py
|
||||
exa-py==1.15.4
|
||||
# via onyx
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# braintrust
|
||||
@@ -243,16 +262,23 @@ fastapi==0.133.1
|
||||
# fastapi-users
|
||||
# onyx
|
||||
fastapi-limiter==0.1.6
|
||||
# via onyx
|
||||
fastapi-users==15.0.4
|
||||
# via fastapi-users-db-sqlalchemy
|
||||
# via
|
||||
# fastapi-users-db-sqlalchemy
|
||||
# onyx
|
||||
fastapi-users-db-sqlalchemy==7.0.0
|
||||
# via onyx
|
||||
fastavro==1.12.1
|
||||
# via cohere
|
||||
fastmcp==3.2.0
|
||||
# via onyx
|
||||
fastuuid==0.14.0
|
||||
# via litellm
|
||||
filelock==3.20.3
|
||||
# via huggingface-hub
|
||||
# via
|
||||
# huggingface-hub
|
||||
# onyx
|
||||
filetype==1.2.0
|
||||
# via unstructured
|
||||
flatbuffers==25.9.23
|
||||
@@ -272,6 +298,7 @@ gitpython==3.1.45
|
||||
google-api-core==2.28.1
|
||||
# via google-api-python-client
|
||||
google-api-python-client==2.86.0
|
||||
# via onyx
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-api-core
|
||||
@@ -281,8 +308,11 @@ google-auth==2.48.0
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-auth-httplib2==0.1.0
|
||||
# via google-api-python-client
|
||||
# via
|
||||
# google-api-python-client
|
||||
# onyx
|
||||
google-auth-oauthlib==1.0.0
|
||||
# via onyx
|
||||
google-genai==1.52.0
|
||||
# via onyx
|
||||
googleapis-common-protos==1.72.0
|
||||
@@ -310,6 +340,7 @@ htmldate==1.9.1
|
||||
httpcore==1.0.9
|
||||
# via
|
||||
# httpx
|
||||
# onyx
|
||||
# unstructured-client
|
||||
httplib2==0.31.0
|
||||
# via
|
||||
@@ -326,16 +357,21 @@ httpx==0.28.1
|
||||
# langsmith
|
||||
# litellm
|
||||
# mcp
|
||||
# onyx
|
||||
# openai
|
||||
# unstructured-client
|
||||
httpx-oauth==0.15.1
|
||||
# via onyx
|
||||
httpx-sse==0.4.3
|
||||
# via
|
||||
# cohere
|
||||
# mcp
|
||||
hubspot-api-client==11.1.0
|
||||
# via onyx
|
||||
huggingface-hub==0.35.3
|
||||
# via tokenizers
|
||||
# via
|
||||
# onyx
|
||||
# tokenizers
|
||||
humanfriendly==10.0
|
||||
# via coloredlogs
|
||||
hyperframe==6.1.0
|
||||
@@ -354,7 +390,9 @@ importlib-metadata==8.7.0
|
||||
# litellm
|
||||
# opentelemetry-api
|
||||
inflection==0.5.1
|
||||
# via pyairtable
|
||||
# via
|
||||
# onyx
|
||||
# pyairtable
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
isodate==0.7.2
|
||||
@@ -376,6 +414,7 @@ jinja2==3.1.6
|
||||
# distributed
|
||||
# litellm
|
||||
jira==3.10.5
|
||||
# via onyx
|
||||
jiter==0.12.0
|
||||
# via openai
|
||||
jmespath==1.0.1
|
||||
@@ -391,7 +430,9 @@ jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
jsonref==1.1.0
|
||||
# via fastmcp
|
||||
# via
|
||||
# fastmcp
|
||||
# onyx
|
||||
jsonschema==4.25.1
|
||||
# via
|
||||
# litellm
|
||||
@@ -409,12 +450,15 @@ kombu==5.5.4
|
||||
kubernetes==31.0.0
|
||||
# via onyx
|
||||
langchain-core==1.2.22
|
||||
# via onyx
|
||||
langdetect==1.0.9
|
||||
# via unstructured
|
||||
langfuse==3.10.0
|
||||
# via onyx
|
||||
langsmith==0.3.45
|
||||
# via langchain-core
|
||||
lazy-imports==1.0.1
|
||||
# via onyx
|
||||
legacy-cgi==2.6.4 ; python_full_version >= '3.13'
|
||||
# via ddtrace
|
||||
litellm==1.81.6
|
||||
@@ -429,6 +473,7 @@ lxml==5.3.0
|
||||
# justext
|
||||
# lxml-html-clean
|
||||
# markitdown
|
||||
# onyx
|
||||
# python-docx
|
||||
# python-pptx
|
||||
# python3-saml
|
||||
@@ -443,7 +488,9 @@ magika==0.6.3
|
||||
makefun==1.16.0
|
||||
# via fastapi-users
|
||||
mako==1.2.4
|
||||
# via alembic
|
||||
# via
|
||||
# alembic
|
||||
# onyx
|
||||
mammoth==1.11.0
|
||||
# via markitdown
|
||||
markdown-it-py==4.0.0
|
||||
@@ -451,6 +498,7 @@ markdown-it-py==4.0.0
|
||||
markdownify==1.2.2
|
||||
# via markitdown
|
||||
markitdown==0.1.2
|
||||
# via onyx
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# jinja2
|
||||
@@ -464,9 +512,11 @@ mcp==1.26.0
|
||||
# via
|
||||
# claude-agent-sdk
|
||||
# fastmcp
|
||||
# onyx
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistune==3.2.0
|
||||
# via onyx
|
||||
more-itertools==10.8.0
|
||||
# via
|
||||
# jaraco-classes
|
||||
@@ -475,10 +525,13 @@ more-itertools==10.8.0
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
msal==1.34.0
|
||||
# via office365-rest-python-client
|
||||
# via
|
||||
# office365-rest-python-client
|
||||
# onyx
|
||||
msgpack==1.1.2
|
||||
# via distributed
|
||||
msoffcrypto-tool==5.4.2
|
||||
# via onyx
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -495,6 +548,7 @@ mypy-extensions==1.0.0
|
||||
# mypy
|
||||
# typing-inspect
|
||||
nest-asyncio==1.6.0
|
||||
# via onyx
|
||||
nltk==3.9.4
|
||||
# via unstructured
|
||||
numpy==2.4.1
|
||||
@@ -509,8 +563,10 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# atlassian-python-api
|
||||
# kubernetes
|
||||
# onyx
|
||||
# requests-oauthlib
|
||||
office365-rest-python-client==2.6.2
|
||||
# via onyx
|
||||
olefile==0.47
|
||||
# via
|
||||
# msoffcrypto-tool
|
||||
@@ -526,11 +582,15 @@ openai==2.14.0
|
||||
openapi-pydantic==0.5.1
|
||||
# via fastmcp
|
||||
openinference-instrumentation==0.1.42
|
||||
# via onyx
|
||||
openinference-semantic-conventions==0.1.25
|
||||
# via openinference-instrumentation
|
||||
openpyxl==3.0.10
|
||||
# via markitdown
|
||||
# via
|
||||
# markitdown
|
||||
# onyx
|
||||
opensearch-py==3.0.0
|
||||
# via onyx
|
||||
opentelemetry-api==1.39.1
|
||||
# via
|
||||
# ddtrace
|
||||
@@ -546,6 +606,7 @@ opentelemetry-exporter-otlp-proto-http==1.39.1
|
||||
# via langfuse
|
||||
opentelemetry-proto==1.39.1
|
||||
# via
|
||||
# onyx
|
||||
# opentelemetry-exporter-otlp-proto-common
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
opentelemetry-sdk==1.39.1
|
||||
@@ -579,6 +640,7 @@ parameterized==0.9.0
|
||||
partd==1.4.2
|
||||
# via dask
|
||||
passlib==1.7.4
|
||||
# via onyx
|
||||
pathable==0.4.4
|
||||
# via jsonschema-path
|
||||
pdfminer-six==20251107
|
||||
@@ -590,7 +652,9 @@ platformdirs==4.5.0
|
||||
# fastmcp
|
||||
# zeep
|
||||
playwright==1.55.0
|
||||
# via pytest-playwright
|
||||
# via
|
||||
# onyx
|
||||
# pytest-playwright
|
||||
pluggy==1.6.0
|
||||
# via pytest
|
||||
ply==3.11
|
||||
@@ -620,9 +684,12 @@ protobuf==6.33.5
|
||||
psutil==7.1.3
|
||||
# via
|
||||
# distributed
|
||||
# onyx
|
||||
# unstructured
|
||||
psycopg2-binary==2.9.9
|
||||
# via onyx
|
||||
puremagic==1.28
|
||||
# via onyx
|
||||
pwdlib==0.3.0
|
||||
# via fastapi-users
|
||||
py==1.11.0
|
||||
@@ -630,6 +697,7 @@ py==1.11.0
|
||||
py-key-value-aio==0.4.4
|
||||
# via fastmcp
|
||||
pyairtable==3.0.1
|
||||
# via onyx
|
||||
pyasn1==0.6.3
|
||||
# via
|
||||
# pyasn1-modules
|
||||
@@ -639,6 +707,7 @@ pyasn1-modules==0.4.2
|
||||
pycparser==2.23 ; implementation_name != 'PyPy'
|
||||
# via cffi
|
||||
pycryptodome==3.19.1
|
||||
# via onyx
|
||||
pydantic==2.11.7
|
||||
# via
|
||||
# agent-client-protocol
|
||||
@@ -665,6 +734,7 @@ pydantic-settings==2.12.0
|
||||
pyee==13.0.0
|
||||
# via playwright
|
||||
pygithub==2.5.0
|
||||
# via onyx
|
||||
pygments==2.20.0
|
||||
# via rich
|
||||
pyjwt==2.12.0
|
||||
@@ -675,13 +745,17 @@ pyjwt==2.12.0
|
||||
# pygithub
|
||||
# simple-salesforce
|
||||
pympler==1.1
|
||||
# via onyx
|
||||
pynacl==1.6.2
|
||||
# via pygithub
|
||||
pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.9.2
|
||||
# via unstructured-client
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
pyperclip==1.11.0
|
||||
# via fastmcp
|
||||
pyreadline3==3.5.4 ; sys_platform == 'win32'
|
||||
@@ -694,7 +768,9 @@ pytest==8.3.5
|
||||
pytest-base-url==2.1.0
|
||||
# via pytest-playwright
|
||||
pytest-mock==3.12.0
|
||||
# via onyx
|
||||
pytest-playwright==0.7.0
|
||||
# via onyx
|
||||
python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -705,9 +781,11 @@ python-dateutil==2.8.2
|
||||
# htmldate
|
||||
# hubspot-api-client
|
||||
# kubernetes
|
||||
# onyx
|
||||
# opensearch-py
|
||||
# pandas
|
||||
python-docx==1.1.2
|
||||
# via onyx
|
||||
python-dotenv==1.1.1
|
||||
# via
|
||||
# braintrust
|
||||
@@ -715,8 +793,10 @@ python-dotenv==1.1.1
|
||||
# litellm
|
||||
# magika
|
||||
# mcp
|
||||
# onyx
|
||||
# pydantic-settings
|
||||
python-gitlab==5.6.0
|
||||
# via onyx
|
||||
python-http-client==3.3.7
|
||||
# via sendgrid
|
||||
python-iso639==2025.11.16
|
||||
@@ -727,15 +807,19 @@ python-multipart==0.0.22
|
||||
# via
|
||||
# fastapi-users
|
||||
# mcp
|
||||
# onyx
|
||||
python-oxmsg==0.0.2
|
||||
# via unstructured
|
||||
python-pptx==0.6.23
|
||||
# via markitdown
|
||||
# via
|
||||
# markitdown
|
||||
# onyx
|
||||
python-slugify==8.0.4
|
||||
# via
|
||||
# braintrust
|
||||
# pytest-playwright
|
||||
python3-saml==1.15.0
|
||||
# via onyx
|
||||
pytz==2025.2
|
||||
# via
|
||||
# dateparser
|
||||
@@ -743,6 +827,7 @@ pytz==2025.2
|
||||
# pandas
|
||||
# zeep
|
||||
pywikibot==9.0.0
|
||||
# via onyx
|
||||
pywin32==311 ; sys_platform == 'win32'
|
||||
# via
|
||||
# mcp
|
||||
@@ -759,9 +844,13 @@ pyyaml==6.0.3
|
||||
# kubernetes
|
||||
# langchain-core
|
||||
rapidfuzz==3.13.0
|
||||
# via unstructured
|
||||
# via
|
||||
# onyx
|
||||
# unstructured
|
||||
redis==5.0.8
|
||||
# via fastapi-limiter
|
||||
# via
|
||||
# fastapi-limiter
|
||||
# onyx
|
||||
referencing==0.36.2
|
||||
# via
|
||||
# jsonschema
|
||||
@@ -792,6 +881,7 @@ requests==2.33.0
|
||||
# matrix-client
|
||||
# msal
|
||||
# office365-rest-python-client
|
||||
# onyx
|
||||
# opensearch-py
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
# pyairtable
|
||||
@@ -817,6 +907,7 @@ requests-oauthlib==1.3.1
|
||||
# google-auth-oauthlib
|
||||
# jira
|
||||
# kubernetes
|
||||
# onyx
|
||||
requests-toolbelt==1.0.0
|
||||
# via
|
||||
# jira
|
||||
@@ -827,6 +918,7 @@ requests-toolbelt==1.0.0
|
||||
retry==0.9.2
|
||||
# via onyx
|
||||
rfc3986==1.5.0
|
||||
# via onyx
|
||||
rich==14.2.0
|
||||
# via
|
||||
# cyclopts
|
||||
@@ -846,12 +938,15 @@ s3transfer==0.13.1
|
||||
secretstorage==3.5.0 ; sys_platform == 'linux'
|
||||
# via keyring
|
||||
sendgrid==6.12.5
|
||||
# via onyx
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
shapely==2.0.6
|
||||
# via onyx
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
simple-salesforce==1.12.6
|
||||
# via onyx
|
||||
six==1.17.0
|
||||
# via
|
||||
# asana
|
||||
@@ -866,6 +961,7 @@ six==1.17.0
|
||||
# python-dateutil
|
||||
# stone
|
||||
slack-sdk==3.20.2
|
||||
# via onyx
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
@@ -880,6 +976,7 @@ sqlalchemy==2.0.15
|
||||
# via
|
||||
# alembic
|
||||
# fastapi-users-db-sqlalchemy
|
||||
# onyx
|
||||
sse-starlette==3.0.3
|
||||
# via mcp
|
||||
sseclient-py==1.8.0
|
||||
@@ -888,11 +985,14 @@ starlette==0.49.3
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# onyx
|
||||
# prometheus-fastapi-instrumentator
|
||||
stone==3.3.1
|
||||
# via dropbox
|
||||
stripe==10.12.0
|
||||
# via onyx
|
||||
supervisor==4.3.0
|
||||
# via onyx
|
||||
sympy==1.14.0
|
||||
# via onnxruntime
|
||||
tblib==3.2.2
|
||||
@@ -905,8 +1005,11 @@ tenacity==9.1.2
|
||||
text-unidecode==1.3
|
||||
# via python-slugify
|
||||
tiktoken==0.7.0
|
||||
# via litellm
|
||||
# via
|
||||
# litellm
|
||||
# onyx
|
||||
timeago==1.0.16
|
||||
# via onyx
|
||||
tld==0.13.1
|
||||
# via courlan
|
||||
tokenizers==0.21.4
|
||||
@@ -930,11 +1033,13 @@ tqdm==4.67.1
|
||||
# openai
|
||||
# unstructured
|
||||
trafilatura==1.12.2
|
||||
# via onyx
|
||||
typer==0.20.0
|
||||
# via mcp
|
||||
types-awscrt==0.28.4
|
||||
# via botocore-stubs
|
||||
types-openpyxl==3.0.4.7
|
||||
# via onyx
|
||||
types-requests==2.32.0.20250328
|
||||
# via cohere
|
||||
types-s3transfer==0.14.0
|
||||
@@ -1000,8 +1105,11 @@ tzlocal==5.3.1
|
||||
uncalled-for==0.2.0
|
||||
# via fastmcp
|
||||
unstructured==0.18.27
|
||||
# via onyx
|
||||
unstructured-client==0.42.6
|
||||
# via unstructured
|
||||
# via
|
||||
# onyx
|
||||
# unstructured
|
||||
uritemplate==4.2.0
|
||||
# via google-api-python-client
|
||||
urllib3==2.6.3
|
||||
@@ -1013,6 +1121,7 @@ urllib3==2.6.3
|
||||
# htmldate
|
||||
# hubspot-api-client
|
||||
# kubernetes
|
||||
# onyx
|
||||
# opensearch-py
|
||||
# pyairtable
|
||||
# pygithub
|
||||
@@ -1062,7 +1171,9 @@ xlrd==2.0.2
|
||||
xlsxwriter==3.2.9
|
||||
# via python-pptx
|
||||
xmlsec==1.3.14
|
||||
# via python3-saml
|
||||
# via
|
||||
# onyx
|
||||
# python3-saml
|
||||
xmltodict==1.0.2
|
||||
# via ddtrace
|
||||
yarl==1.22.0
|
||||
@@ -1076,3 +1187,4 @@ zipp==3.23.0
|
||||
zstandard==0.23.0
|
||||
# via langsmith
|
||||
zulip==0.8.2
|
||||
# via onyx
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --group dev -o backend/requirements/dev.txt
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --extra dev -o backend/requirements/dev.txt
|
||||
agent-client-protocol==0.7.1
|
||||
# via onyx
|
||||
aioboto3==15.1.0
|
||||
@@ -47,6 +47,7 @@ attrs==25.4.0
|
||||
# jsonschema
|
||||
# referencing
|
||||
black==25.1.0
|
||||
# via onyx
|
||||
boto3==1.39.11
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -59,6 +60,7 @@ botocore==1.39.11
|
||||
brotli==1.2.0
|
||||
# via onyx
|
||||
celery-types==0.19.0
|
||||
# via onyx
|
||||
certifi==2025.11.12
|
||||
# via
|
||||
# httpcore
|
||||
@@ -120,6 +122,7 @@ execnet==2.1.2
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
faker==40.1.2
|
||||
# via onyx
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
@@ -153,6 +156,7 @@ h11==0.16.0
|
||||
# httpcore
|
||||
# uvicorn
|
||||
hatchling==1.28.0
|
||||
# via onyx
|
||||
hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||
# via huggingface-hub
|
||||
httpcore==1.0.9
|
||||
@@ -183,6 +187,7 @@ importlib-metadata==8.7.0
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
ipykernel==6.29.5
|
||||
# via onyx
|
||||
ipython==9.7.0
|
||||
# via ipykernel
|
||||
ipython-pygments-lexers==1.1.1
|
||||
@@ -219,11 +224,13 @@ litellm==1.81.6
|
||||
mako==1.2.4
|
||||
# via alembic
|
||||
manygo==0.2.0
|
||||
# via onyx
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# jinja2
|
||||
# mako
|
||||
matplotlib==3.10.8
|
||||
# via onyx
|
||||
matplotlib-inline==0.2.1
|
||||
# via
|
||||
# ipykernel
|
||||
@@ -236,10 +243,12 @@ multidict==6.7.0
|
||||
# aiohttp
|
||||
# yarl
|
||||
mypy==1.13.0
|
||||
# via onyx
|
||||
mypy-extensions==1.0.0
|
||||
# via
|
||||
# black
|
||||
# mypy
|
||||
# onyx
|
||||
nest-asyncio==1.6.0
|
||||
# via ipykernel
|
||||
nodeenv==1.9.1
|
||||
@@ -255,12 +264,15 @@ oauthlib==3.2.2
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.7.3
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
# litellm
|
||||
# onyx
|
||||
openapi-generator-cli==7.17.0
|
||||
# via onyx-devtools
|
||||
# via
|
||||
# onyx
|
||||
# onyx-devtools
|
||||
packaging==24.2
|
||||
# via
|
||||
# black
|
||||
@@ -270,6 +282,7 @@ packaging==24.2
|
||||
# matplotlib
|
||||
# pytest
|
||||
pandas-stubs==2.3.3.251201
|
||||
# via onyx
|
||||
parameterized==0.9.0
|
||||
# via cohere
|
||||
parso==0.8.5
|
||||
@@ -292,6 +305,7 @@ pluggy==1.6.0
|
||||
# hatchling
|
||||
# pytest
|
||||
pre-commit==3.2.2
|
||||
# via onyx
|
||||
prometheus-client==0.23.1
|
||||
# via
|
||||
# onyx
|
||||
@@ -345,16 +359,22 @@ pyparsing==3.2.5
|
||||
# via matplotlib
|
||||
pytest==8.3.5
|
||||
# via
|
||||
# onyx
|
||||
# pytest-alembic
|
||||
# pytest-asyncio
|
||||
# pytest-dotenv
|
||||
# pytest-repeat
|
||||
# pytest-xdist
|
||||
pytest-alembic==0.12.1
|
||||
# via onyx
|
||||
pytest-asyncio==1.3.0
|
||||
# via onyx
|
||||
pytest-dotenv==0.5.2
|
||||
# via onyx
|
||||
pytest-repeat==0.9.4
|
||||
# via onyx
|
||||
pytest-xdist==3.8.0
|
||||
# via onyx
|
||||
python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -387,7 +407,9 @@ referencing==0.36.2
|
||||
regex==2025.11.3
|
||||
# via tiktoken
|
||||
release-tag==0.5.2
|
||||
# via onyx
|
||||
reorder-python-imports-black==3.14.0
|
||||
# via onyx
|
||||
requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
@@ -408,6 +430,7 @@ rpds-py==0.29.0
|
||||
rsa==4.9.1
|
||||
# via google-auth
|
||||
ruff==0.12.0
|
||||
# via onyx
|
||||
s3transfer==0.13.1
|
||||
# via boto3
|
||||
sentry-sdk==2.14.0
|
||||
@@ -461,22 +484,39 @@ traitlets==5.14.3
|
||||
trove-classifiers==2025.12.1.14
|
||||
# via hatchling
|
||||
types-beautifulsoup4==4.12.0.3
|
||||
# via onyx
|
||||
types-html5lib==1.1.11.13
|
||||
# via types-beautifulsoup4
|
||||
# via
|
||||
# onyx
|
||||
# types-beautifulsoup4
|
||||
types-oauthlib==3.2.0.9
|
||||
# via onyx
|
||||
types-passlib==1.7.7.20240106
|
||||
# via onyx
|
||||
types-pillow==10.2.0.20240822
|
||||
# via onyx
|
||||
types-psutil==7.1.3.20251125
|
||||
# via onyx
|
||||
types-psycopg2==2.9.21.10
|
||||
# via onyx
|
||||
types-python-dateutil==2.8.19.13
|
||||
# via onyx
|
||||
types-pytz==2023.3.1.1
|
||||
# via pandas-stubs
|
||||
# via
|
||||
# onyx
|
||||
# pandas-stubs
|
||||
types-pyyaml==6.0.12.11
|
||||
# via onyx
|
||||
types-regex==2023.3.23.1
|
||||
# via onyx
|
||||
types-requests==2.32.0.20250328
|
||||
# via cohere
|
||||
# via
|
||||
# cohere
|
||||
# onyx
|
||||
types-retry==0.9.9.3
|
||||
# via onyx
|
||||
types-setuptools==68.0.0.3
|
||||
# via onyx
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
@@ -534,3 +574,4 @@ yarl==1.22.0
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
zizmor==1.18.0
|
||||
# via onyx
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --group ee -o backend/requirements/ee.txt
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --extra ee -o backend/requirements/ee.txt
|
||||
agent-client-protocol==0.7.1
|
||||
# via onyx
|
||||
aioboto3==15.1.0
|
||||
@@ -182,6 +182,7 @@ packaging==24.2
|
||||
parameterized==0.9.0
|
||||
# via cohere
|
||||
posthog==3.7.4
|
||||
# via onyx
|
||||
prometheus-client==0.23.1
|
||||
# via
|
||||
# onyx
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --group model_server -o backend/requirements/model_server.txt
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --extra model_server -o backend/requirements/model_server.txt
|
||||
accelerate==1.6.0
|
||||
# via onyx
|
||||
agent-client-protocol==0.7.1
|
||||
# via onyx
|
||||
aioboto3==15.1.0
|
||||
@@ -104,6 +105,7 @@ distro==1.9.0
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
einops==0.8.1
|
||||
# via onyx
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
@@ -205,6 +207,7 @@ networkx==3.5
|
||||
numpy==2.4.1
|
||||
# via
|
||||
# accelerate
|
||||
# onyx
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# transformers
|
||||
@@ -360,6 +363,7 @@ s3transfer==0.13.1
|
||||
safetensors==0.5.3
|
||||
# via
|
||||
# accelerate
|
||||
# onyx
|
||||
# transformers
|
||||
scikit-learn==1.7.2
|
||||
# via sentence-transformers
|
||||
@@ -368,6 +372,7 @@ scipy==1.16.3
|
||||
# scikit-learn
|
||||
# sentence-transformers
|
||||
sentence-transformers==4.0.2
|
||||
# via onyx
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
setuptools==80.9.0 ; python_full_version >= '3.12'
|
||||
@@ -406,6 +411,7 @@ tokenizers==0.21.4
|
||||
torch==2.9.1
|
||||
# via
|
||||
# accelerate
|
||||
# onyx
|
||||
# sentence-transformers
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
@@ -414,7 +420,9 @@ tqdm==4.67.1
|
||||
# sentence-transformers
|
||||
# transformers
|
||||
transformers==4.53.0
|
||||
# via sentence-transformers
|
||||
# via
|
||||
# onyx
|
||||
# sentence-transformers
|
||||
triton==3.5.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
types-requests==2.32.0.20250328
|
||||
|
||||
@@ -9,7 +9,6 @@ from unittest.mock import patch
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from ee.onyx.db.license import delete_license
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.db.license import get_used_seats
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
@@ -215,43 +214,3 @@ class TestCheckSeatAvailabilityMultiTenant:
|
||||
assert result.available is False
|
||||
assert result.error_message is not None
|
||||
mock_tenant_count.assert_called_once_with("tenant-abc")
|
||||
|
||||
|
||||
class TestGetUsedSeatsAccountTypeFiltering:
|
||||
"""Verify get_used_seats query excludes SERVICE_ACCOUNT but includes BOT."""
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", False)
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_excludes_service_accounts(self, mock_get_session: MagicMock) -> None:
|
||||
"""SERVICE_ACCOUNT users should not count toward seats."""
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session.execute.return_value.scalar.return_value = 5
|
||||
|
||||
result = get_used_seats()
|
||||
|
||||
assert result == 5
|
||||
# Inspect the compiled query to verify account_type filter
|
||||
call_args = mock_session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "SERVICE_ACCOUNT" in compiled
|
||||
# BOT should NOT be excluded
|
||||
assert "BOT" not in compiled
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", False)
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_still_excludes_ext_perm_user(self, mock_get_session: MagicMock) -> None:
|
||||
"""EXT_PERM_USER exclusion should still be present."""
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session.execute.return_value.scalar.return_value = 3
|
||||
|
||||
get_used_seats()
|
||||
|
||||
call_args = mock_session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "EXT_PERM_USER" in compiled
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from onyx.connectors.google_utils.google_kv import get_auth_url
|
||||
from onyx.connectors.google_utils.google_kv import get_google_app_cred
|
||||
from onyx.connectors.google_utils.google_kv import get_service_account_key
|
||||
from onyx.connectors.google_utils.google_kv import upsert_google_app_cred
|
||||
from onyx.connectors.google_utils.google_kv import upsert_service_account_key
|
||||
from onyx.server.documents.models import GoogleAppCredentials
|
||||
from onyx.server.documents.models import GoogleAppWebCredentials
|
||||
from onyx.server.documents.models import GoogleServiceAccountKey
|
||||
|
||||
|
||||
def _make_app_creds() -> GoogleAppCredentials:
|
||||
return GoogleAppCredentials(
|
||||
web=GoogleAppWebCredentials(
|
||||
client_id="client-id.apps.googleusercontent.com",
|
||||
project_id="test-project",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_secret="secret",
|
||||
redirect_uris=["https://example.com/callback"],
|
||||
javascript_origins=["https://example.com"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _make_service_account_key() -> GoogleServiceAccountKey:
|
||||
return GoogleServiceAccountKey(
|
||||
type="service_account",
|
||||
project_id="test-project",
|
||||
private_key_id="private-key-id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
|
||||
client_email="test@test-project.iam.gserviceaccount.com",
|
||||
client_id="123",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test",
|
||||
universe_domain="googleapis.com",
|
||||
)
|
||||
|
||||
|
||||
def test_upsert_google_app_cred_stores_dict(monkeypatch: Any) -> None:
|
||||
stored: dict[str, Any] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored["key"] = key
|
||||
stored["value"] = value
|
||||
stored["encrypt"] = encrypt
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
upsert_google_app_cred(_make_app_creds(), DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert stored["key"] == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
assert stored["encrypt"] is True
|
||||
assert isinstance(stored["value"], dict)
|
||||
assert stored["value"]["web"]["client_id"] == "client-id.apps.googleusercontent.com"
|
||||
|
||||
|
||||
def test_upsert_service_account_key_stores_dict(monkeypatch: Any) -> None:
|
||||
stored: dict[str, Any] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored["key"] = key
|
||||
stored["value"] = value
|
||||
stored["encrypt"] = encrypt
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
upsert_service_account_key(_make_service_account_key(), DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert stored["key"] == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
assert stored["encrypt"] is True
|
||||
assert isinstance(stored["value"], dict)
|
||||
assert stored["value"]["project_id"] == "test-project"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_google_app_cred_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
payload: dict[str, Any] = _make_app_creds().model_dump(mode="json")
|
||||
stored_value: object = (
|
||||
payload if not legacy_string else _make_app_creds().model_dump_json()
|
||||
)
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
return stored_value
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
creds = get_google_app_cred(DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert creds.web.client_id == "client-id.apps.googleusercontent.com"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_service_account_key_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
stored_value: object = (
|
||||
_make_service_account_key().model_dump(mode="json")
|
||||
if not legacy_string
|
||||
else _make_service_account_key().model_dump_json()
|
||||
)
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
return stored_value
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
key = get_service_account_key(DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert key.client_email == "test@test-project.iam.gserviceaccount.com"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_auth_url_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
payload = _make_app_creds().model_dump(mode="json")
|
||||
stored_value: object = (
|
||||
payload if not legacy_string else _make_app_creds().model_dump_json()
|
||||
)
|
||||
stored_state: dict[str, object] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
return stored_value
|
||||
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored_state["key"] = key
|
||||
stored_state["value"] = value
|
||||
stored_state["encrypt"] = encrypt
|
||||
|
||||
class _StubFlow:
|
||||
def authorization_url(self, prompt: str) -> tuple[str, None]:
|
||||
assert prompt == "consent"
|
||||
return "https://accounts.google.com/o/oauth2/auth?state=test-state", None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
def _from_client_config(
|
||||
_app_config: object, *, scopes: object, redirect_uri: object
|
||||
) -> _StubFlow:
|
||||
del scopes, redirect_uri
|
||||
return _StubFlow()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.InstalledAppFlow.from_client_config",
|
||||
_from_client_config,
|
||||
)
|
||||
|
||||
auth_url = get_auth_url(42, DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert auth_url.startswith("https://accounts.google.com")
|
||||
assert stored_state["value"] == {"value": "test-state"}
|
||||
assert stored_state["encrypt"] is True
|
||||
@@ -6,7 +6,6 @@ import requests
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.connectors.jira.connector import _JIRA_BULK_FETCH_LIMIT
|
||||
from onyx.connectors.jira.connector import bulk_fetch_issues
|
||||
|
||||
|
||||
@@ -146,29 +145,3 @@ def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
|
||||
|
||||
|
||||
def test_bulk_fetch_respects_api_batch_limit() -> None:
|
||||
"""Requests to the bulkfetch endpoint never exceed _JIRA_BULK_FETCH_LIMIT IDs."""
|
||||
client = _mock_jira_client()
|
||||
total_issues = _JIRA_BULK_FETCH_LIMIT * 3 + 7
|
||||
all_ids = [str(i) for i in range(total_issues)]
|
||||
|
||||
batch_sizes: list[int] = []
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
batch_sizes.append(len(ids))
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
result = bulk_fetch_issues(client, all_ids)
|
||||
|
||||
assert len(result) == total_issues
|
||||
# keeping this hardcoded because it's the documented limit
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
assert all(size <= 100 for size in batch_sizes)
|
||||
assert len(batch_sizes) == 4
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
"""Tests for _build_thread_text function."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.slack_search import _build_thread_text
|
||||
|
||||
|
||||
def _make_msg(user: str, text: str, ts: str) -> dict[str, str]:
|
||||
return {"user": user, "text": text, "ts": ts}
|
||||
|
||||
|
||||
class TestBuildThreadText:
|
||||
"""Verify _build_thread_text includes full thread replies up to cap."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_includes_all_replies(self, mock_profiles: MagicMock) -> None:
|
||||
"""All replies within cap are included in output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "parent msg", "1000.0"),
|
||||
_make_msg("U2", "reply 1", "1001.0"),
|
||||
_make_msg("U3", "reply 2", "1002.0"),
|
||||
_make_msg("U4", "reply 3", "1003.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "parent msg" in result
|
||||
assert "reply 1" in result
|
||||
assert "reply 2" in result
|
||||
assert "reply 3" in result
|
||||
assert "..." not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_non_thread_returns_parent_only(self, mock_profiles: MagicMock) -> None:
|
||||
"""Single message (no replies) returns just the parent text."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [_make_msg("U1", "just a message", "1000.0")]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "just a message" in result
|
||||
assert "Replies:" not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_parent_always_first(self, mock_profiles: MagicMock) -> None:
|
||||
"""Thread parent message is always the first line of output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "I am the parent", "1000.0"),
|
||||
_make_msg("U2", "I am a reply", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
parent_pos = result.index("I am the parent")
|
||||
reply_pos = result.index("I am a reply")
|
||||
assert parent_pos < reply_pos
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_user_profiles_resolved(self, mock_profiles: MagicMock) -> None:
|
||||
"""User IDs in thread text are replaced with display names."""
|
||||
mock_profiles.return_value = {"U1": "Alice", "U2": "Bob"}
|
||||
messages = [
|
||||
_make_msg("U1", "hello", "1000.0"),
|
||||
_make_msg("U2", "world", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "Alice" in result
|
||||
assert "Bob" in result
|
||||
assert "<@U1>" not in result
|
||||
assert "<@U2>" not in result
|
||||
@@ -1,108 +0,0 @@
|
||||
"""Tests for Slack URL parsing and direct thread fetch via URL override."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.slack_search import _fetch_thread_from_url
|
||||
from onyx.context.search.federated.slack_search_utils import extract_slack_message_urls
|
||||
|
||||
|
||||
class TestExtractSlackMessageUrls:
|
||||
"""Verify URL parsing extracts channel_id and timestamp correctly."""
|
||||
|
||||
def test_standard_url(self) -> None:
|
||||
query = "summarize https://mycompany.slack.com/archives/C097NBWMY8Y/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 1
|
||||
assert results[0] == ("C097NBWMY8Y", "1775491616.524769")
|
||||
|
||||
def test_multiple_urls(self) -> None:
|
||||
query = (
|
||||
"compare https://co.slack.com/archives/C111/p1234567890123456 "
|
||||
"and https://co.slack.com/archives/C222/p9876543210987654"
|
||||
)
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 2
|
||||
assert results[0] == ("C111", "1234567890.123456")
|
||||
assert results[1] == ("C222", "9876543210.987654")
|
||||
|
||||
def test_no_urls(self) -> None:
|
||||
query = "what happened in #general last week?"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_non_slack_url_ignored(self) -> None:
|
||||
query = "check https://google.com/archives/C111/p1234567890123456"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_timestamp_conversion(self) -> None:
|
||||
"""p prefix removed, dot inserted after 10th digit."""
|
||||
query = "https://x.slack.com/archives/CABC123/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
channel_id, ts = results[0]
|
||||
assert channel_id == "CABC123"
|
||||
assert ts == "1775491616.524769"
|
||||
assert not ts.startswith("p")
|
||||
assert "." in ts
|
||||
|
||||
|
||||
class TestFetchThreadFromUrl:
|
||||
"""Verify _fetch_thread_from_url calls conversations.replies and returns SlackMessage."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search._build_thread_text")
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_successful_fetch(
|
||||
self, mock_webclient_cls: MagicMock, mock_build_thread: MagicMock
|
||||
) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
|
||||
# Mock conversations_replies
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = [
|
||||
{"user": "U1", "text": "parent", "ts": "1775491616.524769"},
|
||||
{"user": "U2", "text": "reply 1", "ts": "1775491617.000000"},
|
||||
{"user": "U3", "text": "reply 2", "ts": "1775491618.000000"},
|
||||
]
|
||||
mock_client.conversations_replies.return_value = mock_response
|
||||
|
||||
# Mock channel info
|
||||
mock_ch_response = MagicMock()
|
||||
mock_ch_response.get.return_value = {"name": "general"}
|
||||
mock_client.conversations_info.return_value = mock_ch_response
|
||||
|
||||
mock_build_thread.return_value = (
|
||||
"U1: parent\n\nReplies:\n\nU2: reply 1\n\nU3: reply 2"
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(
|
||||
channel_id="C097NBWMY8Y", thread_ts="1775491616.524769"
|
||||
)
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
|
||||
assert len(result.messages) == 1
|
||||
msg = result.messages[0]
|
||||
assert msg.channel_id == "C097NBWMY8Y"
|
||||
assert msg.thread_id is None # Prevents double-enrichment
|
||||
assert msg.slack_score == 100000.0
|
||||
assert "parent" in msg.text
|
||||
mock_client.conversations_replies.assert_called_once_with(
|
||||
channel="C097NBWMY8Y", ts="1775491616.524769"
|
||||
)
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_api_error_returns_empty(self, mock_webclient_cls: MagicMock) -> None:
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
mock_client.conversations_replies.side_effect = SlackApiError(
|
||||
message="channel_not_found",
|
||||
response=MagicMock(status_code=404),
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(channel_id="CBAD", thread_ts="1234567890.123456")
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
assert len(result.messages) == 0
|
||||
@@ -505,7 +505,6 @@ class TestGetLMStudioAvailableModels:
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.api_base = "http://localhost:1234"
|
||||
mock_provider.custom_config = {"LM_STUDIO_API_KEY": "stored-secret"}
|
||||
|
||||
response = {
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -10,9 +9,7 @@ from uuid import uuid4
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.server.scim.api import _check_seat_availability
|
||||
from ee.onyx.server.scim.api import _scim_name_to_str
|
||||
from ee.onyx.server.scim.api import _seat_lock_id_for_tenant
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
@@ -744,80 +741,3 @@ class TestEmailCasePreservation:
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
|
||||
class TestSeatLock:
|
||||
"""Tests for the advisory lock in _check_seat_availability."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_abc")
|
||||
def test_acquires_advisory_lock_before_checking(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The advisory lock must be acquired before the seat check runs."""
|
||||
call_order: list[str] = []
|
||||
|
||||
def track_execute(stmt: Any, _params: Any = None) -> None:
|
||||
if "pg_advisory_xact_lock" in str(stmt):
|
||||
call_order.append("lock")
|
||||
|
||||
mock_dal.session.execute.side_effect = track_execute
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop"
|
||||
) as mock_fetch:
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_fn = MagicMock(return_value=mock_result)
|
||||
mock_fetch.return_value = mock_fn
|
||||
|
||||
def track_check(*_args: Any, **_kwargs: Any) -> Any:
|
||||
call_order.append("check")
|
||||
return mock_result
|
||||
|
||||
mock_fn.side_effect = track_check
|
||||
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
assert call_order == ["lock", "check"]
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_xyz")
|
||||
def test_lock_uses_tenant_scoped_key(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The lock id must be derived from the tenant via _seat_lock_id_for_tenant."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_check = MagicMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=mock_check,
|
||||
):
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
mock_dal.session.execute.assert_called_once()
|
||||
params = mock_dal.session.execute.call_args[0][1]
|
||||
assert params["lock_id"] == _seat_lock_id_for_tenant("tenant_xyz")
|
||||
|
||||
def test_seat_lock_id_is_stable_and_tenant_scoped(self) -> None:
|
||||
"""Lock id must be deterministic and differ across tenants."""
|
||||
assert _seat_lock_id_for_tenant("t1") == _seat_lock_id_for_tenant("t1")
|
||||
assert _seat_lock_id_for_tenant("t1") != _seat_lock_id_for_tenant("t2")
|
||||
|
||||
def test_no_lock_when_ee_absent(
|
||||
self,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""No advisory lock should be acquired when the EE check is absent."""
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=None,
|
||||
):
|
||||
result = _check_seat_availability(mock_dal)
|
||||
|
||||
assert result is None
|
||||
mock_dal.session.execute.assert_not_called()
|
||||
|
||||
@@ -70,10 +70,6 @@ spec:
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,sandbox",
|
||||
]
|
||||
ports:
|
||||
- name: metrics
|
||||
containerPort: 9094
|
||||
protocol: TCP
|
||||
resources:
|
||||
{{- toYaml .Values.celery_worker_heavy.resources | nindent 12 }}
|
||||
envFrom:
|
||||
|
||||
@@ -28,7 +28,7 @@ dependencies = [
|
||||
"kubernetes>=31.0.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
[project.optional-dependencies]
|
||||
# Main backend application dependencies
|
||||
backend = [
|
||||
"aiohttp==3.13.4",
|
||||
@@ -195,9 +195,6 @@ model_server = [
|
||||
"sentry-sdk[fastapi,celery,starlette]==2.14.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
default-groups = ["backend", "dev", "ee", "model_server"]
|
||||
|
||||
[tool.mypy]
|
||||
plugins = "sqlalchemy.ext.mypy.plugin"
|
||||
mypy_path = "backend"
|
||||
@@ -233,7 +230,7 @@ follow_imports = "skip"
|
||||
ignore_errors = true
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["tools/ods"]
|
||||
members = ["backend", "tools/ods"]
|
||||
|
||||
[tool.basedpyright]
|
||||
include = ["backend"]
|
||||
|
||||
310
uv.lock
generated
310
uv.lock
generated
@@ -14,6 +14,12 @@ resolution-markers = [
|
||||
"python_full_version < '3.12' and sys_platform != 'win32'",
|
||||
]
|
||||
|
||||
[manifest]
|
||||
members = [
|
||||
"onyx",
|
||||
"onyx-backend",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "accelerate"
|
||||
version = "1.6.0"
|
||||
@@ -4228,7 +4234,7 @@ dependencies = [
|
||||
{ name = "voyageai" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
[package.optional-dependencies]
|
||||
backend = [
|
||||
{ name = "aiohttp" },
|
||||
{ name = "alembic" },
|
||||
@@ -4382,175 +4388,179 @@ model-server = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "accelerate", marker = "extra == 'model-server'", specifier = "==1.6.0" },
|
||||
{ name = "agent-client-protocol", specifier = ">=0.7.1" },
|
||||
{ name = "aioboto3", specifier = "==15.1.0" },
|
||||
{ name = "aiohttp", marker = "extra == 'backend'", specifier = "==3.13.4" },
|
||||
{ name = "alembic", marker = "extra == 'backend'", specifier = "==1.10.4" },
|
||||
{ name = "asana", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "asyncpg", marker = "extra == 'backend'", specifier = "==0.30.0" },
|
||||
{ name = "atlassian-python-api", marker = "extra == 'backend'", specifier = "==3.41.16" },
|
||||
{ name = "azure-cognitiveservices-speech", marker = "extra == 'backend'", specifier = "==1.38.0" },
|
||||
{ name = "beautifulsoup4", marker = "extra == 'backend'", specifier = "==4.12.3" },
|
||||
{ name = "black", marker = "extra == 'dev'", specifier = "==25.1.0" },
|
||||
{ name = "boto3", marker = "extra == 'backend'", specifier = "==1.39.11" },
|
||||
{ name = "boto3-stubs", extras = ["s3"], marker = "extra == 'backend'", specifier = "==1.39.11" },
|
||||
{ name = "braintrust", marker = "extra == 'backend'", specifier = "==0.3.9" },
|
||||
{ name = "brotli", specifier = ">=1.2.0" },
|
||||
{ name = "celery", marker = "extra == 'backend'", specifier = "==5.5.1" },
|
||||
{ name = "celery-types", marker = "extra == 'dev'", specifier = "==0.19.0" },
|
||||
{ name = "chardet", marker = "extra == 'backend'", specifier = "==5.2.0" },
|
||||
{ name = "chonkie", marker = "extra == 'backend'", specifier = "==1.0.10" },
|
||||
{ name = "claude-agent-sdk", specifier = ">=0.1.19" },
|
||||
{ name = "cohere", specifier = "==5.6.1" },
|
||||
{ name = "dask", marker = "extra == 'backend'", specifier = "==2026.1.1" },
|
||||
{ name = "ddtrace", marker = "extra == 'backend'", specifier = "==3.10.0" },
|
||||
{ name = "discord-py", specifier = "==2.4.0" },
|
||||
{ name = "discord-py", marker = "extra == 'backend'", specifier = "==2.4.0" },
|
||||
{ name = "distributed", marker = "extra == 'backend'", specifier = "==2026.1.1" },
|
||||
{ name = "dropbox", marker = "extra == 'backend'", specifier = "==12.0.2" },
|
||||
{ name = "einops", marker = "extra == 'model-server'", specifier = "==0.8.1" },
|
||||
{ name = "exa-py", marker = "extra == 'backend'", specifier = "==1.15.4" },
|
||||
{ name = "faker", marker = "extra == 'dev'", specifier = "==40.1.2" },
|
||||
{ name = "fastapi", specifier = "==0.133.1" },
|
||||
{ name = "fastapi-limiter", marker = "extra == 'backend'", specifier = "==0.1.6" },
|
||||
{ name = "fastapi-users", marker = "extra == 'backend'", specifier = "==15.0.4" },
|
||||
{ name = "fastapi-users-db-sqlalchemy", marker = "extra == 'backend'", specifier = "==7.0.0" },
|
||||
{ name = "fastmcp", marker = "extra == 'backend'", specifier = "==3.2.0" },
|
||||
{ name = "filelock", marker = "extra == 'backend'", specifier = "==3.20.3" },
|
||||
{ name = "google-api-python-client", marker = "extra == 'backend'", specifier = "==2.86.0" },
|
||||
{ name = "google-auth-httplib2", marker = "extra == 'backend'", specifier = "==0.1.0" },
|
||||
{ name = "google-auth-oauthlib", marker = "extra == 'backend'", specifier = "==1.0.0" },
|
||||
{ name = "google-genai", specifier = "==1.52.0" },
|
||||
{ name = "hatchling", marker = "extra == 'dev'", specifier = "==1.28.0" },
|
||||
{ name = "httpcore", marker = "extra == 'backend'", specifier = "==1.0.9" },
|
||||
{ name = "httpx", extras = ["http2"], marker = "extra == 'backend'", specifier = "==0.28.1" },
|
||||
{ name = "httpx-oauth", marker = "extra == 'backend'", specifier = "==0.15.1" },
|
||||
{ name = "hubspot-api-client", marker = "extra == 'backend'", specifier = "==11.1.0" },
|
||||
{ name = "huggingface-hub", marker = "extra == 'backend'", specifier = "==0.35.3" },
|
||||
{ name = "inflection", marker = "extra == 'backend'", specifier = "==0.5.1" },
|
||||
{ name = "ipykernel", marker = "extra == 'dev'", specifier = "==6.29.5" },
|
||||
{ name = "jira", marker = "extra == 'backend'", specifier = "==3.10.5" },
|
||||
{ name = "jsonref", marker = "extra == 'backend'", specifier = "==1.1.0" },
|
||||
{ name = "kubernetes", specifier = ">=31.0.0" },
|
||||
{ name = "kubernetes", marker = "extra == 'backend'", specifier = "==31.0.0" },
|
||||
{ name = "langchain-core", marker = "extra == 'backend'", specifier = "==1.2.22" },
|
||||
{ name = "langfuse", marker = "extra == 'backend'", specifier = "==3.10.0" },
|
||||
{ name = "lazy-imports", marker = "extra == 'backend'", specifier = "==1.0.1" },
|
||||
{ name = "litellm", specifier = "==1.81.6" },
|
||||
{ name = "lxml", marker = "extra == 'backend'", specifier = "==5.3.0" },
|
||||
{ name = "mako", marker = "extra == 'backend'", specifier = "==1.2.4" },
|
||||
{ name = "manygo", marker = "extra == 'dev'", specifier = "==0.2.0" },
|
||||
{ name = "markitdown", extras = ["pdf", "docx", "pptx", "xlsx", "xls"], marker = "extra == 'backend'", specifier = "==0.1.2" },
|
||||
{ name = "matplotlib", marker = "extra == 'dev'", specifier = "==3.10.8" },
|
||||
{ name = "mcp", extras = ["cli"], marker = "extra == 'backend'", specifier = "==1.26.0" },
|
||||
{ name = "mistune", marker = "extra == 'backend'", specifier = "==3.2.0" },
|
||||
{ name = "msal", marker = "extra == 'backend'", specifier = "==1.34.0" },
|
||||
{ name = "msoffcrypto-tool", marker = "extra == 'backend'", specifier = "==5.4.2" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = "==1.13.0" },
|
||||
{ name = "mypy-extensions", marker = "extra == 'dev'", specifier = "==1.0.0" },
|
||||
{ name = "nest-asyncio", marker = "extra == 'backend'", specifier = "==1.6.0" },
|
||||
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
|
||||
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.6.2" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.7.3" },
|
||||
{ name = "openai", specifier = "==2.14.0" },
|
||||
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
|
||||
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
|
||||
{ name = "openpyxl", marker = "extra == 'backend'", specifier = "==3.0.10" },
|
||||
{ name = "opensearch-py", marker = "extra == 'backend'", specifier = "==3.0.0" },
|
||||
{ name = "opentelemetry-proto", marker = "extra == 'backend'", specifier = ">=1.39.0" },
|
||||
{ name = "pandas-stubs", marker = "extra == 'dev'", specifier = "~=2.3.3" },
|
||||
{ name = "passlib", marker = "extra == 'backend'", specifier = "==1.7.4" },
|
||||
{ name = "playwright", marker = "extra == 'backend'", specifier = "==1.55.0" },
|
||||
{ name = "posthog", marker = "extra == 'ee'", specifier = "==3.7.4" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = "==3.2.2" },
|
||||
{ name = "prometheus-client", specifier = ">=0.21.1" },
|
||||
{ name = "prometheus-fastapi-instrumentator", specifier = "==7.1.0" },
|
||||
{ name = "psutil", marker = "extra == 'backend'", specifier = "==7.1.3" },
|
||||
{ name = "psycopg2-binary", marker = "extra == 'backend'", specifier = "==2.9.9" },
|
||||
{ name = "puremagic", marker = "extra == 'backend'", specifier = "==1.28" },
|
||||
{ name = "pyairtable", marker = "extra == 'backend'", specifier = "==3.0.1" },
|
||||
{ name = "pycryptodome", marker = "extra == 'backend'", specifier = "==3.19.1" },
|
||||
{ name = "pydantic", specifier = "==2.11.7" },
|
||||
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
|
||||
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
|
||||
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.9.2" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
|
||||
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
|
||||
{ name = "pytest-dotenv", marker = "extra == 'dev'", specifier = "==0.5.2" },
|
||||
{ name = "pytest-mock", marker = "extra == 'backend'", specifier = "==3.12.0" },
|
||||
{ name = "pytest-playwright", marker = "extra == 'backend'", specifier = "==0.7.0" },
|
||||
{ name = "pytest-repeat", marker = "extra == 'dev'", specifier = "==0.9.4" },
|
||||
{ name = "pytest-xdist", marker = "extra == 'dev'", specifier = "==3.8.0" },
|
||||
{ name = "python-dateutil", marker = "extra == 'backend'", specifier = "==2.8.2" },
|
||||
{ name = "python-docx", marker = "extra == 'backend'", specifier = "==1.1.2" },
|
||||
{ name = "python-dotenv", marker = "extra == 'backend'", specifier = "==1.1.1" },
|
||||
{ name = "python-gitlab", marker = "extra == 'backend'", specifier = "==5.6.0" },
|
||||
{ name = "python-multipart", marker = "extra == 'backend'", specifier = "==0.0.22" },
|
||||
{ name = "python-pptx", marker = "extra == 'backend'", specifier = "==0.6.23" },
|
||||
{ name = "python3-saml", marker = "extra == 'backend'", specifier = "==1.15.0" },
|
||||
{ name = "pywikibot", marker = "extra == 'backend'", specifier = "==9.0.0" },
|
||||
{ name = "rapidfuzz", marker = "extra == 'backend'", specifier = "==3.13.0" },
|
||||
{ name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
|
||||
{ name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
|
||||
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.33.0" },
|
||||
{ name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
|
||||
{ name = "retry", specifier = "==0.9.2" },
|
||||
{ name = "rfc3986", marker = "extra == 'backend'", specifier = "==1.5.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = "==0.12.0" },
|
||||
{ name = "safetensors", marker = "extra == 'model-server'", specifier = "==0.5.3" },
|
||||
{ name = "sendgrid", marker = "extra == 'backend'", specifier = "==6.12.5" },
|
||||
{ name = "sentence-transformers", marker = "extra == 'model-server'", specifier = "==4.0.2" },
|
||||
{ name = "sentry-sdk", specifier = "==2.14.0" },
|
||||
{ name = "sentry-sdk", extras = ["fastapi", "celery", "starlette"], marker = "extra == 'model-server'", specifier = "==2.14.0" },
|
||||
{ name = "shapely", marker = "extra == 'backend'", specifier = "==2.0.6" },
|
||||
{ name = "simple-salesforce", marker = "extra == 'backend'", specifier = "==1.12.6" },
|
||||
{ name = "slack-sdk", marker = "extra == 'backend'", specifier = "==3.20.2" },
|
||||
{ name = "sqlalchemy", extras = ["mypy"], marker = "extra == 'backend'", specifier = "==2.0.15" },
|
||||
{ name = "starlette", marker = "extra == 'backend'", specifier = "==0.49.3" },
|
||||
{ name = "stripe", marker = "extra == 'backend'", specifier = "==10.12.0" },
|
||||
{ name = "supervisor", marker = "extra == 'backend'", specifier = "==4.3.0" },
|
||||
{ name = "tiktoken", marker = "extra == 'backend'", specifier = "==0.7.0" },
|
||||
{ name = "timeago", marker = "extra == 'backend'", specifier = "==1.0.16" },
|
||||
{ name = "torch", marker = "extra == 'model-server'", specifier = "==2.9.1" },
|
||||
{ name = "trafilatura", marker = "extra == 'backend'", specifier = "==1.12.2" },
|
||||
{ name = "transformers", marker = "extra == 'model-server'", specifier = "==4.53.0" },
|
||||
{ name = "types-beautifulsoup4", marker = "extra == 'dev'", specifier = "==4.12.0.3" },
|
||||
{ name = "types-html5lib", marker = "extra == 'dev'", specifier = "==1.1.11.13" },
|
||||
{ name = "types-oauthlib", marker = "extra == 'dev'", specifier = "==3.2.0.9" },
|
||||
{ name = "types-openpyxl", marker = "extra == 'backend'", specifier = "==3.0.4.7" },
|
||||
{ name = "types-passlib", marker = "extra == 'dev'", specifier = "==1.7.7.20240106" },
|
||||
{ name = "types-pillow", marker = "extra == 'dev'", specifier = "==10.2.0.20240822" },
|
||||
{ name = "types-psutil", marker = "extra == 'dev'", specifier = "==7.1.3.20251125" },
|
||||
{ name = "types-psycopg2", marker = "extra == 'dev'", specifier = "==2.9.21.10" },
|
||||
{ name = "types-python-dateutil", marker = "extra == 'dev'", specifier = "==2.8.19.13" },
|
||||
{ name = "types-pytz", marker = "extra == 'dev'", specifier = "==2023.3.1.1" },
|
||||
{ name = "types-pyyaml", marker = "extra == 'dev'", specifier = "==6.0.12.11" },
|
||||
{ name = "types-regex", marker = "extra == 'dev'", specifier = "==2023.3.23.1" },
|
||||
{ name = "types-requests", marker = "extra == 'dev'", specifier = "==2.32.0.20250328" },
|
||||
{ name = "types-retry", marker = "extra == 'dev'", specifier = "==0.9.9.3" },
|
||||
{ name = "types-setuptools", marker = "extra == 'dev'", specifier = "==68.0.0.3" },
|
||||
{ name = "unstructured", marker = "extra == 'backend'", specifier = "==0.18.27" },
|
||||
{ name = "unstructured-client", marker = "extra == 'backend'", specifier = "==0.42.6" },
|
||||
{ name = "urllib3", marker = "extra == 'backend'", specifier = "==2.6.3" },
|
||||
{ name = "uvicorn", specifier = "==0.35.0" },
|
||||
{ name = "voyageai", specifier = "==0.2.3" },
|
||||
{ name = "xmlsec", marker = "extra == 'backend'", specifier = "==1.3.14" },
|
||||
{ name = "zizmor", marker = "extra == 'dev'", specifier = "==1.18.0" },
|
||||
{ name = "zulip", marker = "extra == 'backend'", specifier = "==0.8.2" },
|
||||
]
|
||||
provides-extras = ["backend", "dev", "ee", "model-server"]
|
||||
|
||||
[[package]]
|
||||
name = "onyx-backend"
|
||||
version = "0.0.0"
|
||||
source = { virtual = "backend" }
|
||||
dependencies = [
|
||||
{ name = "onyx", extra = ["backend", "dev", "ee"] },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
backend = [
|
||||
{ name = "aiohttp", specifier = "==3.13.4" },
|
||||
{ name = "alembic", specifier = "==1.10.4" },
|
||||
{ name = "asana", specifier = "==5.0.8" },
|
||||
{ name = "asyncpg", specifier = "==0.30.0" },
|
||||
{ name = "atlassian-python-api", specifier = "==3.41.16" },
|
||||
{ name = "azure-cognitiveservices-speech", specifier = "==1.38.0" },
|
||||
{ name = "beautifulsoup4", specifier = "==4.12.3" },
|
||||
{ name = "boto3", specifier = "==1.39.11" },
|
||||
{ name = "boto3-stubs", extras = ["s3"], specifier = "==1.39.11" },
|
||||
{ name = "braintrust", specifier = "==0.3.9" },
|
||||
{ name = "celery", specifier = "==5.5.1" },
|
||||
{ name = "chardet", specifier = "==5.2.0" },
|
||||
{ name = "chonkie", specifier = "==1.0.10" },
|
||||
{ name = "dask", specifier = "==2026.1.1" },
|
||||
{ name = "ddtrace", specifier = "==3.10.0" },
|
||||
{ name = "discord-py", specifier = "==2.4.0" },
|
||||
{ name = "distributed", specifier = "==2026.1.1" },
|
||||
{ name = "dropbox", specifier = "==12.0.2" },
|
||||
{ name = "exa-py", specifier = "==1.15.4" },
|
||||
{ name = "fastapi-limiter", specifier = "==0.1.6" },
|
||||
{ name = "fastapi-users", specifier = "==15.0.4" },
|
||||
{ name = "fastapi-users-db-sqlalchemy", specifier = "==7.0.0" },
|
||||
{ name = "fastmcp", specifier = "==3.2.0" },
|
||||
{ name = "filelock", specifier = "==3.20.3" },
|
||||
{ name = "google-api-python-client", specifier = "==2.86.0" },
|
||||
{ name = "google-auth-httplib2", specifier = "==0.1.0" },
|
||||
{ name = "google-auth-oauthlib", specifier = "==1.0.0" },
|
||||
{ name = "httpcore", specifier = "==1.0.9" },
|
||||
{ name = "httpx", extras = ["http2"], specifier = "==0.28.1" },
|
||||
{ name = "httpx-oauth", specifier = "==0.15.1" },
|
||||
{ name = "hubspot-api-client", specifier = "==11.1.0" },
|
||||
{ name = "huggingface-hub", specifier = "==0.35.3" },
|
||||
{ name = "inflection", specifier = "==0.5.1" },
|
||||
{ name = "jira", specifier = "==3.10.5" },
|
||||
{ name = "jsonref", specifier = "==1.1.0" },
|
||||
{ name = "kubernetes", specifier = "==31.0.0" },
|
||||
{ name = "langchain-core", specifier = "==1.2.22" },
|
||||
{ name = "langfuse", specifier = "==3.10.0" },
|
||||
{ name = "lazy-imports", specifier = "==1.0.1" },
|
||||
{ name = "lxml", specifier = "==5.3.0" },
|
||||
{ name = "mako", specifier = "==1.2.4" },
|
||||
{ name = "markitdown", extras = ["pdf", "docx", "pptx", "xlsx", "xls"], specifier = "==0.1.2" },
|
||||
{ name = "mcp", extras = ["cli"], specifier = "==1.26.0" },
|
||||
{ name = "mistune", specifier = "==3.2.0" },
|
||||
{ name = "msal", specifier = "==1.34.0" },
|
||||
{ name = "msoffcrypto-tool", specifier = "==5.4.2" },
|
||||
{ name = "nest-asyncio", specifier = "==1.6.0" },
|
||||
{ name = "oauthlib", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", specifier = "==2.6.2" },
|
||||
{ name = "openinference-instrumentation", specifier = "==0.1.42" },
|
||||
{ name = "openpyxl", specifier = "==3.0.10" },
|
||||
{ name = "opensearch-py", specifier = "==3.0.0" },
|
||||
{ name = "opentelemetry-proto", specifier = ">=1.39.0" },
|
||||
{ name = "passlib", specifier = "==1.7.4" },
|
||||
{ name = "playwright", specifier = "==1.55.0" },
|
||||
{ name = "psutil", specifier = "==7.1.3" },
|
||||
{ name = "psycopg2-binary", specifier = "==2.9.9" },
|
||||
{ name = "puremagic", specifier = "==1.28" },
|
||||
{ name = "pyairtable", specifier = "==3.0.1" },
|
||||
{ name = "pycryptodome", specifier = "==3.19.1" },
|
||||
{ name = "pygithub", specifier = "==2.5.0" },
|
||||
{ name = "pympler", specifier = "==1.1" },
|
||||
{ name = "pypandoc-binary", specifier = "==1.16.2" },
|
||||
{ name = "pypdf", specifier = "==6.9.2" },
|
||||
{ name = "pytest-mock", specifier = "==3.12.0" },
|
||||
{ name = "pytest-playwright", specifier = "==0.7.0" },
|
||||
{ name = "python-dateutil", specifier = "==2.8.2" },
|
||||
{ name = "python-docx", specifier = "==1.1.2" },
|
||||
{ name = "python-dotenv", specifier = "==1.1.1" },
|
||||
{ name = "python-gitlab", specifier = "==5.6.0" },
|
||||
{ name = "python-multipart", specifier = "==0.0.22" },
|
||||
{ name = "python-pptx", specifier = "==0.6.23" },
|
||||
{ name = "python3-saml", specifier = "==1.15.0" },
|
||||
{ name = "pywikibot", specifier = "==9.0.0" },
|
||||
{ name = "rapidfuzz", specifier = "==3.13.0" },
|
||||
{ name = "redis", specifier = "==5.0.8" },
|
||||
{ name = "requests", specifier = "==2.33.0" },
|
||||
{ name = "requests-oauthlib", specifier = "==1.3.1" },
|
||||
{ name = "rfc3986", specifier = "==1.5.0" },
|
||||
{ name = "sendgrid", specifier = "==6.12.5" },
|
||||
{ name = "shapely", specifier = "==2.0.6" },
|
||||
{ name = "simple-salesforce", specifier = "==1.12.6" },
|
||||
{ name = "slack-sdk", specifier = "==3.20.2" },
|
||||
{ name = "sqlalchemy", extras = ["mypy"], specifier = "==2.0.15" },
|
||||
{ name = "starlette", specifier = "==0.49.3" },
|
||||
{ name = "stripe", specifier = "==10.12.0" },
|
||||
{ name = "supervisor", specifier = "==4.3.0" },
|
||||
{ name = "tiktoken", specifier = "==0.7.0" },
|
||||
{ name = "timeago", specifier = "==1.0.16" },
|
||||
{ name = "trafilatura", specifier = "==1.12.2" },
|
||||
{ name = "types-openpyxl", specifier = "==3.0.4.7" },
|
||||
{ name = "unstructured", specifier = "==0.18.27" },
|
||||
{ name = "unstructured-client", specifier = "==0.42.6" },
|
||||
{ name = "urllib3", specifier = "==2.6.3" },
|
||||
{ name = "xmlsec", specifier = "==1.3.14" },
|
||||
{ name = "zulip", specifier = "==0.8.2" },
|
||||
]
|
||||
dev = [
|
||||
{ name = "black", specifier = "==25.1.0" },
|
||||
{ name = "celery-types", specifier = "==0.19.0" },
|
||||
{ name = "faker", specifier = "==40.1.2" },
|
||||
{ name = "hatchling", specifier = "==1.28.0" },
|
||||
{ name = "ipykernel", specifier = "==6.29.5" },
|
||||
{ name = "manygo", specifier = "==0.2.0" },
|
||||
{ name = "matplotlib", specifier = "==3.10.8" },
|
||||
{ name = "mypy", specifier = "==1.13.0" },
|
||||
{ name = "mypy-extensions", specifier = "==1.0.0" },
|
||||
{ name = "onyx-devtools", specifier = "==0.7.3" },
|
||||
{ name = "openapi-generator-cli", specifier = "==7.17.0" },
|
||||
{ name = "pandas-stubs", specifier = "~=2.3.3" },
|
||||
{ name = "pre-commit", specifier = "==3.2.2" },
|
||||
{ name = "pytest", specifier = "==8.3.5" },
|
||||
{ name = "pytest-alembic", specifier = "==0.12.1" },
|
||||
{ name = "pytest-asyncio", specifier = "==1.3.0" },
|
||||
{ name = "pytest-dotenv", specifier = "==0.5.2" },
|
||||
{ name = "pytest-repeat", specifier = "==0.9.4" },
|
||||
{ name = "pytest-xdist", specifier = "==3.8.0" },
|
||||
{ name = "release-tag", specifier = "==0.5.2" },
|
||||
{ name = "reorder-python-imports-black", specifier = "==3.14.0" },
|
||||
{ name = "ruff", specifier = "==0.12.0" },
|
||||
{ name = "types-beautifulsoup4", specifier = "==4.12.0.3" },
|
||||
{ name = "types-html5lib", specifier = "==1.1.11.13" },
|
||||
{ name = "types-oauthlib", specifier = "==3.2.0.9" },
|
||||
{ name = "types-passlib", specifier = "==1.7.7.20240106" },
|
||||
{ name = "types-pillow", specifier = "==10.2.0.20240822" },
|
||||
{ name = "types-psutil", specifier = "==7.1.3.20251125" },
|
||||
{ name = "types-psycopg2", specifier = "==2.9.21.10" },
|
||||
{ name = "types-python-dateutil", specifier = "==2.8.19.13" },
|
||||
{ name = "types-pytz", specifier = "==2023.3.1.1" },
|
||||
{ name = "types-pyyaml", specifier = "==6.0.12.11" },
|
||||
{ name = "types-regex", specifier = "==2023.3.23.1" },
|
||||
{ name = "types-requests", specifier = "==2.32.0.20250328" },
|
||||
{ name = "types-retry", specifier = "==0.9.9.3" },
|
||||
{ name = "types-setuptools", specifier = "==68.0.0.3" },
|
||||
{ name = "zizmor", specifier = "==1.18.0" },
|
||||
]
|
||||
ee = [{ name = "posthog", specifier = "==3.7.4" }]
|
||||
model-server = [
|
||||
{ name = "accelerate", specifier = "==1.6.0" },
|
||||
{ name = "einops", specifier = "==0.8.1" },
|
||||
{ name = "numpy", specifier = "==2.4.1" },
|
||||
{ name = "safetensors", specifier = "==0.5.3" },
|
||||
{ name = "sentence-transformers", specifier = "==4.0.2" },
|
||||
{ name = "sentry-sdk", extras = ["fastapi", "celery", "starlette"], specifier = "==2.14.0" },
|
||||
{ name = "torch", specifier = "==2.9.1" },
|
||||
{ name = "transformers", specifier = "==4.53.0" },
|
||||
]
|
||||
[package.metadata]
|
||||
requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable = "." }]
|
||||
|
||||
[[package]]
|
||||
name = "onyx-devtools"
|
||||
|
||||
16
web/package-lock.json
generated
16
web/package-lock.json
generated
@@ -47,7 +47,6 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.0.0",
|
||||
"cookies-next": "^5.1.0",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^3.6.0",
|
||||
"docx-preview": "^0.3.7",
|
||||
"favicon-fetch": "^1.0.0",
|
||||
@@ -8844,15 +8843,6 @@
|
||||
"react": ">= 16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/copy-to-clipboard": {
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/copy-to-clipboard/-/copy-to-clipboard-3.3.3.tgz",
|
||||
"integrity": "sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"toggle-selection": "^1.0.6"
|
||||
}
|
||||
},
|
||||
"node_modules/core-js": {
|
||||
"version": "3.46.0",
|
||||
"hasInstallScript": true,
|
||||
@@ -17436,12 +17426,6 @@
|
||||
"node": ">=8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/toggle-selection": {
|
||||
"version": "1.0.6",
|
||||
"resolved": "https://registry.npmjs.org/toggle-selection/-/toggle-selection-1.0.6.tgz",
|
||||
"integrity": "sha512-BiZS+C1OS8g/q2RRbJmy59xpyghNBqrr6k5L/uKBGRsTfxmu3ffiRnd8mlGPUVayg8pvfi5urfnu8TU7DVOkLQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/toposort": {
|
||||
"version": "2.0.2",
|
||||
"license": "MIT"
|
||||
|
||||
@@ -65,7 +65,6 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.0.0",
|
||||
"cookies-next": "^5.1.0",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^3.6.0",
|
||||
"docx-preview": "^0.3.7",
|
||||
"favicon-fetch": "^1.0.0",
|
||||
|
||||
@@ -17,7 +17,6 @@ import DocumentSetCard from "@/sections/cards/DocumentSetCard";
|
||||
import CollapsibleSection from "@/app/admin/agents/CollapsibleSection";
|
||||
import { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
import { StandardAnswerCategoryDropdownField } from "@/components/standardAnswers/StandardAnswerCategoryDropdown";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import { RadioGroup } from "@/components/ui/radio-group";
|
||||
import { RadioGroupItemField } from "@/components/ui/RadioGroupItemField";
|
||||
import { AlertCircle } from "lucide-react";
|
||||
@@ -127,24 +126,6 @@ export function SlackChannelConfigFormFields({
|
||||
return documentSets.filter((ds) => !documentSetContainsSync(ds));
|
||||
}, [documentSets]);
|
||||
|
||||
const searchAgentOptions = useMemo(
|
||||
() =>
|
||||
availableAgents.map((persona) => ({
|
||||
label: persona.name,
|
||||
value: String(persona.id),
|
||||
})),
|
||||
[availableAgents]
|
||||
);
|
||||
|
||||
const nonSearchAgentOptions = useMemo(
|
||||
() =>
|
||||
nonSearchAgents.map((persona) => ({
|
||||
label: persona.name,
|
||||
value: String(persona.id),
|
||||
})),
|
||||
[nonSearchAgents]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const invalidSelected = values.document_sets.filter((dsId: number) =>
|
||||
unselectableSets.some((us) => us.id === dsId)
|
||||
@@ -374,14 +355,12 @@ export function SlackChannelConfigFormFields({
|
||||
</>
|
||||
</SubLabel>
|
||||
|
||||
<InputComboBox
|
||||
placeholder="Search for an agent..."
|
||||
value={String(values.persona_id ?? "")}
|
||||
onValueChange={(val) =>
|
||||
setFieldValue("persona_id", val ? Number(val) : null)
|
||||
}
|
||||
options={searchAgentOptions}
|
||||
strict
|
||||
<SelectorFormField
|
||||
name="persona_id"
|
||||
options={availableAgents.map((persona) => ({
|
||||
name: persona.name,
|
||||
value: persona.id,
|
||||
}))}
|
||||
/>
|
||||
{viewSyncEnabledAgents && syncEnabledAgents.length > 0 && (
|
||||
<div className="mt-4">
|
||||
@@ -440,14 +419,12 @@ export function SlackChannelConfigFormFields({
|
||||
</>
|
||||
</SubLabel>
|
||||
|
||||
<InputComboBox
|
||||
placeholder="Search for an agent..."
|
||||
value={String(values.persona_id ?? "")}
|
||||
onValueChange={(val) =>
|
||||
setFieldValue("persona_id", val ? Number(val) : null)
|
||||
}
|
||||
options={nonSearchAgentOptions}
|
||||
strict
|
||||
<SelectorFormField
|
||||
name="persona_id"
|
||||
options={nonSearchAgents.map((persona) => ({
|
||||
name: persona.name,
|
||||
value: persona.id,
|
||||
}))}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { defaultTailwindCSS } from "@/components/icons/icons";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import { IconProps } from "@opal/types";
|
||||
|
||||
export interface ModelIconProps extends IconProps {
|
||||
|
||||
@@ -1 +1 @@
|
||||
export { default } from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
export { default } from "@/refresh-pages/admin/LLMProviderConfigurationPage";
|
||||
|
||||
@@ -5,7 +5,7 @@ import { Button } from "@opal/components";
|
||||
import { Text } from "@opal/components";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { SvgEyeOff, SvgX } from "@opal/icons";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import AgentMessage, {
|
||||
AgentMessageProps,
|
||||
} from "@/app/app/message/messageComponents/AgentMessage";
|
||||
@@ -28,8 +28,6 @@ export interface MultiModelPanelProps {
|
||||
isNonPreferredInSelection: boolean;
|
||||
/** Callback when user clicks this panel to select as preferred */
|
||||
onSelect: () => void;
|
||||
/** Callback to deselect this panel as preferred */
|
||||
onDeselect?: () => void;
|
||||
/** Callback to hide/show this panel */
|
||||
onToggleVisibility: () => void;
|
||||
/** Props to pass through to AgentMessage */
|
||||
@@ -65,7 +63,6 @@ export default function MultiModelPanel({
|
||||
isHidden,
|
||||
isNonPreferredInSelection,
|
||||
onSelect,
|
||||
onDeselect,
|
||||
onToggleVisibility,
|
||||
agentMessageProps,
|
||||
errorMessage,
|
||||
@@ -96,25 +93,11 @@ export default function MultiModelPanel({
|
||||
rightChildren={
|
||||
<div className="flex items-center gap-1 px-2">
|
||||
{isPreferred && (
|
||||
<>
|
||||
<span className="text-action-link-05 shrink-0">
|
||||
<Text font="secondary-body" color="inherit" nowrap>
|
||||
Preferred Response
|
||||
</Text>
|
||||
</span>
|
||||
{onDeselect && (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgX}
|
||||
size="sm"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onDeselect();
|
||||
}}
|
||||
tooltip="Deselect preferred response"
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
<span className="text-action-link-05 shrink-0">
|
||||
<Text font="secondary-body" color="inherit" nowrap>
|
||||
Preferred Response
|
||||
</Text>
|
||||
</span>
|
||||
)}
|
||||
{!isPreferred && (
|
||||
<Button
|
||||
@@ -163,7 +146,6 @@ export default function MultiModelPanel({
|
||||
<AgentMessage
|
||||
{...agentMessageProps}
|
||||
hideFooter={isNonPreferredInSelection}
|
||||
disableTTS
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -30,7 +30,7 @@ const SELECTION_PANEL_W = 400;
|
||||
// Compact width for hidden panels in the carousel track
|
||||
const HIDDEN_PANEL_W = 220;
|
||||
// Generation-mode panel widths (from Figma)
|
||||
const GEN_PANEL_W_2 = 720; // 2 panels side-by-side
|
||||
const GEN_PANEL_W_2 = 640; // 2 panels side-by-side
|
||||
const GEN_PANEL_W_3 = 436; // 3 panels side-by-side
|
||||
// Gap between panels — matches CSS gap-6 (24px)
|
||||
const PANEL_GAP = 24;
|
||||
@@ -64,31 +64,14 @@ export default function MultiModelResponseView({
|
||||
onMessageSelection,
|
||||
onHiddenPanelsChange,
|
||||
}: MultiModelResponseViewProps) {
|
||||
// Initialize preferredIndex from the backend's preferred_response_id when
|
||||
// loading an existing conversation.
|
||||
const [preferredIndex, setPreferredIndex] = useState<number | null>(() => {
|
||||
if (!parentMessage?.preferredResponseId) return null;
|
||||
const match = responses.find(
|
||||
(r) => r.messageId === parentMessage.preferredResponseId
|
||||
);
|
||||
return match?.modelIndex ?? null;
|
||||
});
|
||||
const [preferredIndex, setPreferredIndex] = useState<number | null>(null);
|
||||
const [hiddenPanels, setHiddenPanels] = useState<Set<number>>(new Set());
|
||||
// Controls animation: false = panels at start position, true = panels at peek position
|
||||
const [selectionEntered, setSelectionEntered] = useState(
|
||||
() => preferredIndex !== null
|
||||
);
|
||||
// Tracks the deselect animation timeout so it can be cancelled if the user
|
||||
// re-selects a panel during the 450ms animation window.
|
||||
const deselectTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
// True while the reverse animation is playing (deselect → back to equal panels)
|
||||
const [selectionExiting, setSelectionExiting] = useState(false);
|
||||
const [selectionEntered, setSelectionEntered] = useState(false);
|
||||
// Measures the overflow-hidden carousel container for responsive preferred-panel sizing.
|
||||
const [trackContainerW, setTrackContainerW] = useState(0);
|
||||
const roRef = useRef<ResizeObserver | null>(null);
|
||||
const trackContainerElRef = useRef<HTMLDivElement | null>(null);
|
||||
const trackContainerRef = useCallback((el: HTMLDivElement | null) => {
|
||||
trackContainerElRef.current = el;
|
||||
if (roRef.current) {
|
||||
roRef.current.disconnect();
|
||||
roRef.current = null;
|
||||
@@ -107,9 +90,6 @@ export default function MultiModelResponseView({
|
||||
number | null
|
||||
>(null);
|
||||
const preferredRoRef = useRef<ResizeObserver | null>(null);
|
||||
// Refs to each panel wrapper for height animation on deselect
|
||||
const panelElsRef = useRef<Map<number, HTMLDivElement>>(new Map());
|
||||
|
||||
// Tracks which non-preferred panels overflow the preferred height cap
|
||||
const [overflowingPanels, setOverflowingPanels] = useState<Set<number>>(
|
||||
new Set()
|
||||
@@ -172,48 +152,15 @@ export default function MultiModelResponseView({
|
||||
const handleSelectPreferred = useCallback(
|
||||
(modelIndex: number) => {
|
||||
if (isGenerating) return;
|
||||
|
||||
// Cancel any pending deselect animation so it doesn't overwrite this selection
|
||||
if (deselectTimeoutRef.current !== null) {
|
||||
clearTimeout(deselectTimeoutRef.current);
|
||||
deselectTimeoutRef.current = null;
|
||||
setSelectionExiting(false);
|
||||
}
|
||||
|
||||
// Only freeze scroll when entering selection mode for the first time.
|
||||
// When switching preferred within selection mode, panels are already
|
||||
// capped and the track just slides — no height changes to worry about.
|
||||
const alreadyInSelection = preferredIndex !== null;
|
||||
if (!alreadyInSelection) {
|
||||
const scrollContainer = trackContainerElRef.current?.closest(
|
||||
"[data-chat-scroll]"
|
||||
) as HTMLElement | null;
|
||||
const scrollTop = scrollContainer?.scrollTop ?? 0;
|
||||
if (scrollContainer) scrollContainer.style.overflow = "hidden";
|
||||
|
||||
setTimeout(() => {
|
||||
if (scrollContainer) {
|
||||
scrollContainer.scrollTop = scrollTop;
|
||||
requestAnimationFrame(() => {
|
||||
requestAnimationFrame(() => {
|
||||
if (scrollContainer) {
|
||||
scrollContainer.scrollTop = scrollTop;
|
||||
scrollContainer.style.overflow = "";
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}, 450);
|
||||
}
|
||||
|
||||
setPreferredIndex(modelIndex);
|
||||
const response = responses.find((r) => r.modelIndex === modelIndex);
|
||||
if (!response) return;
|
||||
if (onMessageSelection) {
|
||||
onMessageSelection(response.nodeId);
|
||||
}
|
||||
|
||||
// Persist preferred response + sync `latestChildNodeId`. Backend's
|
||||
// `set_preferred_response` updates `latest_child_message_id`; if the
|
||||
// frontend chain walk disagrees, the next follow-up fails with
|
||||
// "not on the latest mainline".
|
||||
// Persist preferred response to backend + update local tree so the
|
||||
// input bar unblocks (awaitingPreferredSelection clears).
|
||||
if (parentMessage?.messageId && response.messageId && currentSessionId) {
|
||||
setPreferredResponse(parentMessage.messageId, response.messageId).catch(
|
||||
(err) => console.error("Failed to persist preferred response:", err)
|
||||
@@ -229,7 +176,6 @@ export default function MultiModelResponseView({
|
||||
updated.set(parentMessage.nodeId, {
|
||||
...userMsg,
|
||||
preferredResponseId: response.messageId,
|
||||
latestChildNodeId: response.nodeId,
|
||||
});
|
||||
updateSessionMessageTree(currentSessionId, updated);
|
||||
}
|
||||
@@ -239,111 +185,17 @@ export default function MultiModelResponseView({
|
||||
[
|
||||
isGenerating,
|
||||
responses,
|
||||
preferredIndex,
|
||||
onMessageSelection,
|
||||
parentMessage,
|
||||
currentSessionId,
|
||||
updateSessionMessageTree,
|
||||
]
|
||||
);
|
||||
|
||||
// NOTE: Deselect only clears the local tree — no backend call to clear
|
||||
// preferred_response_id. The SetPreferredResponseRequest model doesn't
|
||||
// accept null. A backend endpoint for clearing preference would be needed
|
||||
// if deselect should persist across reloads.
|
||||
const handleDeselectPreferred = useCallback(() => {
|
||||
const scrollContainer = trackContainerElRef.current?.closest(
|
||||
"[data-chat-scroll]"
|
||||
) as HTMLElement | null;
|
||||
|
||||
// Animate panels back to equal positions, then clear preferred after transition
|
||||
setSelectionExiting(true);
|
||||
setSelectionEntered(false);
|
||||
deselectTimeoutRef.current = setTimeout(() => {
|
||||
deselectTimeoutRef.current = null;
|
||||
const scrollTop = scrollContainer?.scrollTop ?? 0;
|
||||
if (scrollContainer) scrollContainer.style.overflow = "hidden";
|
||||
|
||||
// Before clearing state, animate each capped panel's height from
|
||||
// its current clientHeight to its natural scrollHeight.
|
||||
const animations: Animation[] = [];
|
||||
panelElsRef.current.forEach((el, modelIndex) => {
|
||||
if (modelIndex === preferredIndex) return;
|
||||
if (hiddenPanels.has(modelIndex)) return;
|
||||
const from = el.clientHeight;
|
||||
const to = el.scrollHeight;
|
||||
if (to <= from) return;
|
||||
// Lock current height, remove maxHeight cap, then animate
|
||||
el.style.maxHeight = `${from}px`;
|
||||
el.style.overflow = "hidden";
|
||||
const anim = el.animate(
|
||||
[{ maxHeight: `${from}px` }, { maxHeight: `${to}px` }],
|
||||
{
|
||||
duration: 350,
|
||||
easing: "cubic-bezier(0.2, 0, 0, 1)",
|
||||
fill: "forwards",
|
||||
}
|
||||
);
|
||||
animations.push(anim);
|
||||
anim.onfinish = () => {
|
||||
el.style.maxHeight = "";
|
||||
el.style.overflow = "";
|
||||
};
|
||||
});
|
||||
|
||||
setSelectionExiting(false);
|
||||
setPreferredIndex(null);
|
||||
|
||||
// Restore scroll after animations + React settle
|
||||
const restoreScroll = () => {
|
||||
requestAnimationFrame(() => {
|
||||
if (scrollContainer) {
|
||||
scrollContainer.scrollTop = scrollTop;
|
||||
scrollContainer.style.overflow = "";
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
if (animations.length > 0) {
|
||||
Promise.all(animations.map((a) => a.finished))
|
||||
.then(restoreScroll)
|
||||
.catch(restoreScroll);
|
||||
} else {
|
||||
restoreScroll();
|
||||
}
|
||||
|
||||
// Clear preferredResponseId in the local tree so input bar re-gates
|
||||
if (parentMessage && currentSessionId) {
|
||||
const tree = useChatSessionStore
|
||||
.getState()
|
||||
.sessions.get(currentSessionId)?.messageTree;
|
||||
if (tree) {
|
||||
const userMsg = tree.get(parentMessage.nodeId);
|
||||
if (userMsg) {
|
||||
const updated = new Map(tree);
|
||||
updated.set(parentMessage.nodeId, {
|
||||
...userMsg,
|
||||
preferredResponseId: undefined,
|
||||
});
|
||||
updateSessionMessageTree(currentSessionId, updated);
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 450);
|
||||
}, [
|
||||
parentMessage,
|
||||
currentSessionId,
|
||||
updateSessionMessageTree,
|
||||
preferredIndex,
|
||||
hiddenPanels,
|
||||
]);
|
||||
|
||||
// Clear preferred selection when generation starts
|
||||
// Reset selection state when generation restarts
|
||||
useEffect(() => {
|
||||
if (isGenerating) {
|
||||
setPreferredIndex(null);
|
||||
setHasEnteredSelection(false);
|
||||
setSelectionExiting(false);
|
||||
}
|
||||
}, [isGenerating]);
|
||||
|
||||
@@ -352,39 +204,22 @@ export default function MultiModelResponseView({
|
||||
(r) => r.modelIndex === preferredIndex
|
||||
);
|
||||
|
||||
// Track whether selection mode was ever entered — once it has been,
|
||||
// we stay in the selection layout (even after deselect) to avoid a
|
||||
// jarring DOM swap between the two layout strategies.
|
||||
const [hasEnteredSelection, setHasEnteredSelection] = useState(
|
||||
() => preferredIndex !== null
|
||||
);
|
||||
|
||||
const isActivelySelected =
|
||||
// Selection mode when preferred is set, found in responses, not generating, and at least 2 visible panels
|
||||
const showSelectionMode =
|
||||
preferredIndex !== null &&
|
||||
preferredIdx !== -1 &&
|
||||
!isGenerating &&
|
||||
visibleResponses.length > 1;
|
||||
|
||||
// Trigger the slide-out animation one frame after entering selection mode
|
||||
useEffect(() => {
|
||||
if (isActivelySelected) setHasEnteredSelection(true);
|
||||
}, [isActivelySelected]);
|
||||
|
||||
// Use the selection layout once a preferred response has been chosen,
|
||||
// even after deselect. Only fall through to generation layout before
|
||||
// the first selection or during active streaming.
|
||||
const showSelectionMode = isActivelySelected || hasEnteredSelection;
|
||||
|
||||
// Trigger the slide-out animation one frame after a preferred panel is selected.
|
||||
// Uses isActivelySelected (not showSelectionMode) so re-selecting after a
|
||||
// deselect still triggers the animation.
|
||||
useEffect(() => {
|
||||
if (!isActivelySelected) {
|
||||
// Don't reset selectionEntered here — handleDeselectPreferred manages it
|
||||
if (!showSelectionMode) {
|
||||
setSelectionEntered(false);
|
||||
return;
|
||||
}
|
||||
const raf = requestAnimationFrame(() => setSelectionEntered(true));
|
||||
return () => cancelAnimationFrame(raf);
|
||||
}, [isActivelySelected]);
|
||||
}, [showSelectionMode]);
|
||||
|
||||
// Build panel props — isHidden reflects actual hidden state
|
||||
const buildPanelProps = useCallback(
|
||||
@@ -396,7 +231,6 @@ export default function MultiModelResponseView({
|
||||
isHidden: hiddenPanels.has(response.modelIndex),
|
||||
isNonPreferredInSelection: isNonPreferred,
|
||||
onSelect: () => handleSelectPreferred(response.modelIndex),
|
||||
onDeselect: handleDeselectPreferred,
|
||||
onToggleVisibility: () => toggleVisibility(response.modelIndex),
|
||||
agentMessageProps: {
|
||||
rawPackets: response.packets,
|
||||
@@ -421,7 +255,6 @@ export default function MultiModelResponseView({
|
||||
preferredIndex,
|
||||
hiddenPanels,
|
||||
handleSelectPreferred,
|
||||
handleDeselectPreferred,
|
||||
toggleVisibility,
|
||||
chatState,
|
||||
llmManager,
|
||||
@@ -477,30 +310,25 @@ export default function MultiModelResponseView({
|
||||
<div
|
||||
ref={trackContainerRef}
|
||||
className="w-full overflow-hidden"
|
||||
style={
|
||||
isActivelySelected
|
||||
? {
|
||||
maskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
WebkitMaskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
style={{
|
||||
maskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
WebkitMaskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
className="flex items-start"
|
||||
style={{
|
||||
gap: `${PANEL_GAP}px`,
|
||||
transition:
|
||||
selectionEntered || selectionExiting
|
||||
? "transform 0.45s cubic-bezier(0.2, 0, 0, 1)"
|
||||
: "none",
|
||||
transition: selectionEntered
|
||||
? "transform 0.45s cubic-bezier(0.2, 0, 0, 1)"
|
||||
: "none",
|
||||
transform: trackTransform,
|
||||
}}
|
||||
>
|
||||
{responses.map((r, i) => {
|
||||
const isHidden = hiddenPanels.has(r.modelIndex);
|
||||
const isPref = r.modelIndex === preferredIndex;
|
||||
const isNonPref = !isHidden && !isPref && preferredIndex !== null;
|
||||
const isNonPref = !isHidden && !isPref;
|
||||
const finalW = selectionWidths[i]!;
|
||||
const startW = isHidden ? HIDDEN_PANEL_W : SELECTION_PANEL_W;
|
||||
const capped = isNonPref && preferredPanelHeight != null;
|
||||
@@ -509,11 +337,6 @@ export default function MultiModelResponseView({
|
||||
<div
|
||||
key={r.modelIndex}
|
||||
ref={(el) => {
|
||||
if (el) {
|
||||
panelElsRef.current.set(r.modelIndex, el);
|
||||
} else {
|
||||
panelElsRef.current.delete(r.modelIndex);
|
||||
}
|
||||
if (isPref) preferredPanelRef(el);
|
||||
if (capped && el) {
|
||||
const doesOverflow = el.scrollHeight > el.clientHeight;
|
||||
@@ -530,10 +353,9 @@ export default function MultiModelResponseView({
|
||||
style={{
|
||||
width: `${selectionEntered ? finalW : startW}px`,
|
||||
flexShrink: 0,
|
||||
transition:
|
||||
selectionEntered || selectionExiting
|
||||
? "width 0.45s cubic-bezier(0.2, 0, 0, 1)"
|
||||
: "none",
|
||||
transition: selectionEntered
|
||||
? "width 0.45s cubic-bezier(0.2, 0, 0, 1)"
|
||||
: "none",
|
||||
maxHeight: capped ? preferredPanelHeight : undefined,
|
||||
overflow: capped ? "hidden" : undefined,
|
||||
position: capped ? "relative" : undefined,
|
||||
@@ -566,7 +388,7 @@ export default function MultiModelResponseView({
|
||||
|
||||
return (
|
||||
<div className="overflow-x-auto">
|
||||
<div className="flex gap-6 items-start justify-center w-full">
|
||||
<div className="flex gap-6 items-start w-full">
|
||||
{responses.map((r) => {
|
||||
const isHidden = hiddenPanels.has(r.modelIndex);
|
||||
return (
|
||||
|
||||
@@ -1,25 +1,3 @@
|
||||
/* Map Tailwind Typography prose variables to the project's color tokens.
|
||||
These auto-switch for dark mode via colors.css — no dark: modifier needed.
|
||||
Note: text-05 = highest contrast, text-01 = lowest. */
|
||||
.prose-onyx {
|
||||
--tw-prose-body: var(--text-05);
|
||||
--tw-prose-headings: var(--text-05);
|
||||
--tw-prose-lead: var(--text-04);
|
||||
--tw-prose-links: var(--action-link-05);
|
||||
--tw-prose-bold: var(--text-05);
|
||||
--tw-prose-counters: var(--text-03);
|
||||
--tw-prose-bullets: var(--text-03);
|
||||
--tw-prose-hr: var(--border-02);
|
||||
--tw-prose-quotes: var(--text-04);
|
||||
--tw-prose-quote-borders: var(--border-02);
|
||||
--tw-prose-captions: var(--text-03);
|
||||
--tw-prose-code: var(--text-05);
|
||||
--tw-prose-pre-code: var(--text-04);
|
||||
--tw-prose-pre-bg: var(--background-code-01);
|
||||
--tw-prose-th-borders: var(--border-02);
|
||||
--tw-prose-td-borders: var(--border-01);
|
||||
}
|
||||
|
||||
/* Light mode syntax highlighting (Atom One Light) */
|
||||
.hljs {
|
||||
color: #383a42 !important;
|
||||
@@ -258,102 +236,23 @@ pre[class*="language-"] {
|
||||
scrollbar-color: #4b5563 #1f2937;
|
||||
}
|
||||
|
||||
/* Card wrapper — holds the background, border-radius, padding, and fade overlay.
|
||||
Does NOT scroll — the inner .markdown-table-breakout handles that. */
|
||||
.markdown-table-card {
|
||||
position: relative;
|
||||
background: var(--background-neutral-01);
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Scrollable table container — sits inside the card.
|
||||
* Table breakout container - allows tables to extend beyond their parent's
|
||||
* constrained width to use the full container query width (100cqw).
|
||||
*
|
||||
* Requires an ancestor element with `container-type: inline-size` (@container in Tailwind).
|
||||
*
|
||||
* How the math works:
|
||||
* - width: 100cqw → expand to full container query width
|
||||
* - marginLeft: calc((100% - 100cqw) / 2) → negative margin pulls element left
|
||||
* (100% is parent width, 100cqw is larger, so result is negative)
|
||||
* - paddingLeft/Right: calc((100cqw - 100%) / 2) → padding keeps content aligned
|
||||
* with original position while allowing scroll area to extend
|
||||
*/
|
||||
.markdown-table-breakout {
|
||||
overflow-x: auto;
|
||||
|
||||
/* Always reserve scrollbar height so hover doesn't shift content.
|
||||
Thumb is transparent by default, revealed on hover. */
|
||||
scrollbar-width: thin; /* Firefox — always shows track */
|
||||
scrollbar-color: transparent transparent; /* invisible thumb + track */
|
||||
}
|
||||
.markdown-table-breakout::-webkit-scrollbar {
|
||||
height: 6px;
|
||||
}
|
||||
.markdown-table-breakout::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
.markdown-table-breakout::-webkit-scrollbar-thumb {
|
||||
background: transparent;
|
||||
border-radius: 3px;
|
||||
}
|
||||
.markdown-table-breakout:hover {
|
||||
scrollbar-color: var(--border-03) transparent; /* Firefox — reveal thumb */
|
||||
}
|
||||
.markdown-table-breakout:hover::-webkit-scrollbar-thumb {
|
||||
background: var(--border-03);
|
||||
}
|
||||
|
||||
/* Fade the right edge via an ::after overlay on the non-scrolling card.
|
||||
Stays pinned while table scrolls; doesn't affect the sticky column. */
|
||||
.markdown-table-card::after {
|
||||
content: "";
|
||||
position: absolute;
|
||||
top: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
width: 2rem;
|
||||
pointer-events: none;
|
||||
z-index: 2;
|
||||
background: linear-gradient(
|
||||
to right,
|
||||
transparent,
|
||||
var(--background-neutral-01)
|
||||
);
|
||||
border-radius: 0 0.5rem 0.5rem 0;
|
||||
opacity: 0;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.markdown-table-card[data-overflows="true"]::after {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Sticky first column — inherits the container's background so it
|
||||
matches regardless of theme or custom wallpaper. */
|
||||
.markdown-table-breakout th:first-child,
|
||||
.markdown-table-breakout td:first-child {
|
||||
position: sticky;
|
||||
left: 0;
|
||||
z-index: 1;
|
||||
padding-left: 0.75rem;
|
||||
background: var(--background-neutral-01);
|
||||
}
|
||||
.markdown-table-breakout th:last-child,
|
||||
.markdown-table-breakout td:last-child {
|
||||
padding-right: 0.75rem;
|
||||
}
|
||||
|
||||
/* Shadow on sticky column when scrolled. Uses an ::after pseudo-element
|
||||
so it isn't clipped by the overflow container or the mask-image fade. */
|
||||
.markdown-table-breakout th:first-child::after,
|
||||
.markdown-table-breakout td:first-child::after {
|
||||
content: "";
|
||||
position: absolute;
|
||||
top: 0;
|
||||
right: -6px;
|
||||
bottom: 0;
|
||||
width: 6px;
|
||||
pointer-events: none;
|
||||
opacity: 0;
|
||||
transition: opacity 0.15s;
|
||||
box-shadow: inset 6px 0 8px -4px var(--alpha-grey-100-25);
|
||||
}
|
||||
.dark .markdown-table-breakout th:first-child::after,
|
||||
.dark .markdown-table-breakout td:first-child::after {
|
||||
box-shadow: inset 6px 0 8px -4px var(--alpha-grey-100-60);
|
||||
}
|
||||
.markdown-table-breakout[data-scrolled="true"] th:first-child::after,
|
||||
.markdown-table-breakout[data-scrolled="true"] td:first-child::after {
|
||||
opacity: 1;
|
||||
width: 100cqw;
|
||||
margin-left: calc((100% - 100cqw) / 2);
|
||||
padding-left: calc((100cqw - 100%) / 2);
|
||||
padding-right: calc((100cqw - 100%) / 2);
|
||||
}
|
||||
|
||||
@@ -51,8 +51,6 @@ export interface AgentMessageProps {
|
||||
processingDurationSeconds?: number;
|
||||
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
|
||||
hideFooter?: boolean;
|
||||
/** Skip TTS streaming (used in multi-model where voice doesn't apply) */
|
||||
disableTTS?: boolean;
|
||||
}
|
||||
|
||||
// TODO: Consider more robust comparisons:
|
||||
@@ -101,7 +99,6 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
parentMessage,
|
||||
processingDurationSeconds,
|
||||
hideFooter,
|
||||
disableTTS,
|
||||
}: AgentMessageProps) {
|
||||
const markdownRef = useRef<HTMLDivElement>(null);
|
||||
const finalAnswerRef = useRef<HTMLDivElement>(null);
|
||||
@@ -205,9 +202,6 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
// Skip if we've already finished TTS for this message
|
||||
if (ttsCompletedRef.current) return;
|
||||
|
||||
// Multi-model: skip TTS entirely
|
||||
if (disableTTS) return;
|
||||
|
||||
// If user cancelled generation, do not send more text to TTS.
|
||||
if (stopPacketSeen && stopReason === StopReason.USER_CANCELLED) {
|
||||
ttsCompletedRef.current = true;
|
||||
@@ -311,7 +305,7 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
onRenderComplete();
|
||||
}
|
||||
}}
|
||||
animate={!stopPacketSeen}
|
||||
animate={false}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useCallback, useEffect, useRef, useMemo, JSX } from "react";
|
||||
import React, { useCallback, useMemo, JSX } from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
@@ -17,66 +17,6 @@ import { transformLinkUri, cn } from "@/lib/utils";
|
||||
import { InMessageImage } from "@/app/app/components/files/images/InMessageImage";
|
||||
import { extractChatImageFileId } from "@/app/app/components/files/images/utils";
|
||||
|
||||
/** Table wrapper that detects horizontal overflow and shows a fade + scrollbar. */
|
||||
interface ScrollableTableProps
|
||||
extends React.TableHTMLAttributes<HTMLTableElement> {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export function ScrollableTable({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: ScrollableTableProps) {
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
const wrapRef = useRef<HTMLDivElement>(null);
|
||||
const tableRef = useRef<HTMLTableElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const el = scrollRef.current;
|
||||
const wrap = wrapRef.current;
|
||||
const table = tableRef.current;
|
||||
if (!el || !wrap) return;
|
||||
|
||||
const check = () => {
|
||||
const overflows = el.scrollWidth > el.clientWidth;
|
||||
const atEnd = el.scrollLeft + el.clientWidth >= el.scrollWidth - 2;
|
||||
wrap.dataset.overflows = overflows && !atEnd ? "true" : "false";
|
||||
el.dataset.scrolled = el.scrollLeft > 0 ? "true" : "false";
|
||||
};
|
||||
|
||||
check();
|
||||
el.addEventListener("scroll", check, { passive: true });
|
||||
// Observe both the scroll container (parent resize) and the table
|
||||
// itself (content growth during streaming).
|
||||
const ro = new ResizeObserver(check);
|
||||
ro.observe(el);
|
||||
if (table) ro.observe(table);
|
||||
|
||||
return () => {
|
||||
el.removeEventListener("scroll", check);
|
||||
ro.disconnect();
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div ref={wrapRef} className="markdown-table-card">
|
||||
<div ref={scrollRef} className="markdown-table-breakout">
|
||||
<table
|
||||
ref={tableRef}
|
||||
className={cn(
|
||||
className,
|
||||
"min-w-full !my-0 [&_th]:whitespace-nowrap [&_td]:whitespace-nowrap"
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes content for markdown rendering by handling code blocks and LaTeX
|
||||
*/
|
||||
@@ -187,9 +127,11 @@ export const useMarkdownComponents = (
|
||||
},
|
||||
table: ({ node, className, children, ...props }: any) => {
|
||||
return (
|
||||
<ScrollableTable className={className} {...props}>
|
||||
{children}
|
||||
</ScrollableTable>
|
||||
<div className="markdown-table-breakout">
|
||||
<table className={cn(className, "min-w-full")} {...props}>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
code: ({ node, className, children }: any) => {
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
import React, { useEffect, useMemo, useRef, useState } from "react";
|
||||
import ReactMarkdown, { Components } from "react-markdown";
|
||||
import type { PluggableList } from "unified";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
import rehypeHighlight from "rehype-highlight";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import "katex/dist/katex.min.css";
|
||||
|
||||
import { useTypewriter } from "@/hooks/useTypewriter";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
import {
|
||||
ChatPacket,
|
||||
PacketType,
|
||||
@@ -16,22 +8,16 @@ import {
|
||||
} from "../../../services/streamingModels";
|
||||
import { MessageRenderer, FullChatState } from "../interfaces";
|
||||
import { isFinalAnswerComplete } from "../../../services/packetUtils";
|
||||
import { processContent } from "../markdownUtils";
|
||||
import { useMarkdownRenderer } from "../markdownUtils";
|
||||
import { BlinkingBar } from "../../BlinkingBar";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import {
|
||||
MemoizedAnchor,
|
||||
MemoizedParagraph,
|
||||
} from "@/app/app/message/MemoizedTextComponents";
|
||||
import { extractCodeText } from "@/app/app/message/codeUtils";
|
||||
import { CodeBlock } from "@/app/app/message/CodeBlock";
|
||||
import { InMessageImage } from "@/app/app/components/files/images/InMessageImage";
|
||||
import { extractChatImageFileId } from "@/app/app/components/files/images/utils";
|
||||
import { cn, transformLinkUri } from "@/lib/utils";
|
||||
|
||||
/** Maps a visible-char count to a markdown index (skips formatting chars,
|
||||
* extends to word boundary). Used by the voice-sync reveal path only. */
|
||||
/**
|
||||
* Maps a cleaned character position to the corresponding position in markdown text.
|
||||
* This allows progressive reveal to work with markdown formatting.
|
||||
*/
|
||||
function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
// Skip patterns that don't contribute to visible character count
|
||||
const skipChars = new Set(["*", "`", "#"]);
|
||||
let cleanIndex = 0;
|
||||
let mdIndex = 0;
|
||||
@@ -39,11 +25,13 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
while (cleanIndex < cleanChars && mdIndex < markdown.length) {
|
||||
const char = markdown[mdIndex];
|
||||
|
||||
// Skip markdown formatting characters
|
||||
if (char !== undefined && skipChars.has(char)) {
|
||||
mdIndex++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle link syntax [text](url) - skip the (url) part but count the text
|
||||
if (
|
||||
char === "]" &&
|
||||
mdIndex + 1 < markdown.length &&
|
||||
@@ -60,6 +48,7 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
mdIndex++;
|
||||
}
|
||||
|
||||
// Extend to word boundary to avoid cutting mid-word
|
||||
while (
|
||||
mdIndex < markdown.length &&
|
||||
markdown[mdIndex] !== " " &&
|
||||
@@ -71,15 +60,8 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
return mdIndex;
|
||||
}
|
||||
|
||||
// Cheap streaming plugins (gfm only) → cheap per-frame parse. Full
|
||||
// pipeline flips in once, at the end, for syntax highlighting + math.
|
||||
const STREAMING_REMARK_PLUGINS: PluggableList = [remarkGfm];
|
||||
const STREAMING_REHYPE_PLUGINS: PluggableList = [];
|
||||
const FULL_REMARK_PLUGINS: PluggableList = [
|
||||
remarkGfm,
|
||||
[remarkMath, { singleDollarTextMath: true }],
|
||||
];
|
||||
const FULL_REHYPE_PLUGINS: PluggableList = [rehypeHighlight, rehypeKatex];
|
||||
// Control the rate of packet streaming (packets per second)
|
||||
const PACKET_DELAY_MS = 10;
|
||||
|
||||
export const MessageTextRenderer: MessageRenderer<
|
||||
ChatPacket,
|
||||
@@ -96,17 +78,19 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
stopReason,
|
||||
children,
|
||||
}) => {
|
||||
// If we're animating and the final answer is already complete, show more packets initially
|
||||
const initialPacketCount = animate
|
||||
? packets.length > 0
|
||||
? 1 // Otherwise start with 1 packet
|
||||
: 0
|
||||
: -1; // Show all if not animating
|
||||
|
||||
const [displayedPacketCount, setDisplayedPacketCount] =
|
||||
useState(initialPacketCount);
|
||||
const lastStableSyncedContentRef = useRef("");
|
||||
const lastVisibleContentRef = useRef("");
|
||||
|
||||
// Timeout guard: if TTS doesn't start within 5s of voice sync
|
||||
// activating, fall back to normal streaming. Prevents permanent
|
||||
// content suppression when the voice WebSocket fails to connect.
|
||||
const [voiceSyncTimedOut, setVoiceSyncTimedOut] = useState(false);
|
||||
const voiceSyncTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
|
||||
null
|
||||
);
|
||||
|
||||
// Get voice mode context for progressive text reveal synced with audio
|
||||
const {
|
||||
revealedCharCount,
|
||||
autoPlayback,
|
||||
@@ -115,6 +99,7 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
isAwaitingAutoPlaybackStart,
|
||||
} = useVoiceMode();
|
||||
|
||||
// Get the full content from all packets
|
||||
const fullContent = packets
|
||||
.map((packet) => {
|
||||
if (
|
||||
@@ -129,74 +114,117 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
|
||||
const shouldUseAutoPlaybackSync =
|
||||
autoPlayback &&
|
||||
!voiceSyncTimedOut &&
|
||||
typeof messageNodeId === "number" &&
|
||||
activeMessageNodeId === messageNodeId;
|
||||
|
||||
// Start/clear the timeout when voice sync activates/deactivates.
|
||||
// Animation effect - gradually increase displayed packets at controlled rate
|
||||
useEffect(() => {
|
||||
if (shouldUseAutoPlaybackSync && isAwaitingAutoPlaybackStart) {
|
||||
if (!voiceSyncTimeoutRef.current) {
|
||||
voiceSyncTimeoutRef.current = setTimeout(() => {
|
||||
setVoiceSyncTimedOut(true);
|
||||
}, 5000);
|
||||
}
|
||||
} else {
|
||||
// TTS started or sync deactivated — clear timeout
|
||||
if (voiceSyncTimeoutRef.current) {
|
||||
clearTimeout(voiceSyncTimeoutRef.current);
|
||||
voiceSyncTimeoutRef.current = null;
|
||||
}
|
||||
if (voiceSyncTimedOut && !autoPlayback) setVoiceSyncTimedOut(false);
|
||||
if (!animate) {
|
||||
setDisplayedPacketCount(-1); // Show all packets
|
||||
return;
|
||||
}
|
||||
return () => {
|
||||
if (voiceSyncTimeoutRef.current) {
|
||||
clearTimeout(voiceSyncTimeoutRef.current);
|
||||
voiceSyncTimeoutRef.current = null;
|
||||
}
|
||||
};
|
||||
}, [
|
||||
shouldUseAutoPlaybackSync,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
isAudioSyncActive,
|
||||
voiceSyncTimedOut,
|
||||
]);
|
||||
|
||||
// Normal streaming hands full text to the typewriter. Voice-sync
|
||||
// paths pre-slice and bypass. If shouldUseAutoPlaybackSync is false
|
||||
// (including after the 5s timeout), all paths fall through to fullContent.
|
||||
if (displayedPacketCount >= 0 && displayedPacketCount < packets.length) {
|
||||
const timer = setTimeout(() => {
|
||||
setDisplayedPacketCount((prev) => Math.min(prev + 1, packets.length));
|
||||
}, PACKET_DELAY_MS);
|
||||
|
||||
return () => clearTimeout(timer);
|
||||
}
|
||||
}, [animate, displayedPacketCount, packets.length]);
|
||||
|
||||
// Reset displayed count when packet array changes significantly (e.g., new message)
|
||||
useEffect(() => {
|
||||
if (animate && packets.length < displayedPacketCount) {
|
||||
const resetCount = isFinalAnswerComplete(packets)
|
||||
? Math.min(10, packets.length)
|
||||
: packets.length > 0
|
||||
? 1
|
||||
: 0;
|
||||
setDisplayedPacketCount(resetCount);
|
||||
}
|
||||
}, [animate, packets.length, displayedPacketCount]);
|
||||
|
||||
// Only mark as complete when all packets are received AND displayed
|
||||
useEffect(() => {
|
||||
if (isFinalAnswerComplete(packets)) {
|
||||
// If animating, wait until all packets are displayed
|
||||
if (
|
||||
animate &&
|
||||
displayedPacketCount >= 0 &&
|
||||
displayedPacketCount < packets.length
|
||||
) {
|
||||
return;
|
||||
}
|
||||
onComplete();
|
||||
}
|
||||
}, [packets, onComplete, animate, displayedPacketCount]);
|
||||
|
||||
// Get content based on displayed packet count or audio progress
|
||||
const computedContent = useMemo(() => {
|
||||
// Hold response in "thinking" state only while autoplay startup is pending.
|
||||
if (shouldUseAutoPlaybackSync && isAwaitingAutoPlaybackStart) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Sync text with audio only for the message currently being spoken.
|
||||
if (shouldUseAutoPlaybackSync && isAudioSyncActive) {
|
||||
const MIN_REVEAL_CHARS = 12;
|
||||
if (revealedCharCount < MIN_REVEAL_CHARS) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Reveal text progressively based on audio progress
|
||||
const revealPos = getRevealPosition(fullContent, revealedCharCount);
|
||||
return fullContent.slice(0, Math.max(revealPos, 0));
|
||||
}
|
||||
|
||||
// During an active synced turn, if sync temporarily drops, keep current reveal
|
||||
// instead of jumping to full content or blanking.
|
||||
if (shouldUseAutoPlaybackSync && !stopPacketSeen) {
|
||||
return lastStableSyncedContentRef.current;
|
||||
}
|
||||
|
||||
return fullContent;
|
||||
// Standard behavior when auto-playback is off
|
||||
if (!animate || displayedPacketCount === -1) {
|
||||
return fullContent; // Show all content
|
||||
}
|
||||
|
||||
// Packet-based reveal (when auto-playback is disabled)
|
||||
return packets
|
||||
.slice(0, displayedPacketCount)
|
||||
.map((packet) => {
|
||||
if (
|
||||
packet.obj.type === PacketType.MESSAGE_DELTA ||
|
||||
packet.obj.type === PacketType.MESSAGE_START
|
||||
) {
|
||||
return packet.obj.content;
|
||||
}
|
||||
return "";
|
||||
})
|
||||
.join("");
|
||||
}, [
|
||||
shouldUseAutoPlaybackSync,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
isAudioSyncActive,
|
||||
revealedCharCount,
|
||||
animate,
|
||||
displayedPacketCount,
|
||||
fullContent,
|
||||
packets,
|
||||
revealedCharCount,
|
||||
autoPlayback,
|
||||
isAudioSyncActive,
|
||||
activeMessageNodeId,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
messageNodeId,
|
||||
shouldUseAutoPlaybackSync,
|
||||
stopPacketSeen,
|
||||
]);
|
||||
|
||||
// Monotonic guard for voice sync + freeze on user cancel.
|
||||
// Keep synced text monotonic: once visible, never regress or disappear between chunks.
|
||||
const content = useMemo(() => {
|
||||
const wasUserCancelled = stopReason === StopReason.USER_CANCELLED;
|
||||
|
||||
// On user cancel during live streaming, freeze at exactly what was already
|
||||
// visible to prevent flicker. On history reload (animate=false), the ref
|
||||
// starts empty so we must use computedContent directly.
|
||||
if (wasUserCancelled && animate) {
|
||||
return lastVisibleContentRef.current;
|
||||
}
|
||||
@@ -214,10 +242,13 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
return computedContent;
|
||||
}
|
||||
|
||||
// If content shape changed unexpectedly mid-stream, prefer the stable version
|
||||
// to avoid flicker/dumps.
|
||||
if (!stopPacketSeen || wasUserCancelled) {
|
||||
return last;
|
||||
}
|
||||
|
||||
// For normal completed responses, allow final full content.
|
||||
return computedContent;
|
||||
}, [
|
||||
computedContent,
|
||||
@@ -227,6 +258,7 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
animate,
|
||||
]);
|
||||
|
||||
// Sync the stable ref outside of useMemo to avoid side effects during render.
|
||||
useEffect(() => {
|
||||
if (stopReason === StopReason.USER_CANCELLED) {
|
||||
return;
|
||||
@@ -238,128 +270,13 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
}
|
||||
}, [content, shouldUseAutoPlaybackSync, stopReason]);
|
||||
|
||||
// Track last actually rendered content so cancel can freeze without dumping buffered text.
|
||||
useEffect(() => {
|
||||
if (content.length > 0) {
|
||||
lastVisibleContentRef.current = content;
|
||||
}
|
||||
}, [content]);
|
||||
|
||||
const isStreamingAnimationEnabled =
|
||||
animate &&
|
||||
!shouldUseAutoPlaybackSync &&
|
||||
stopReason !== StopReason.USER_CANCELLED;
|
||||
|
||||
const isStreamFinished = isFinalAnswerComplete(packets);
|
||||
|
||||
const displayedContent = useTypewriter(content, isStreamingAnimationEnabled);
|
||||
|
||||
// One-way signal: stream done AND typewriter caught up. Do NOT derive
|
||||
// this from "typewriter currently behind" — it oscillates mid-stream
|
||||
// between packet bursts and would thrash the plugin pipeline.
|
||||
const streamFullyDisplayed =
|
||||
isStreamFinished && displayedContent.length >= content.length;
|
||||
|
||||
// Fire onComplete exactly once per mount. `onComplete` is an inline
|
||||
// arrow in AgentMessage so its identity changes on every parent render;
|
||||
// without this guard, each new identity would re-fire the effect once
|
||||
// `streamFullyDisplayed` is true.
|
||||
const onCompleteFiredRef = useRef(false);
|
||||
useEffect(() => {
|
||||
if (streamFullyDisplayed && !onCompleteFiredRef.current) {
|
||||
onCompleteFiredRef.current = true;
|
||||
onComplete();
|
||||
}
|
||||
}, [streamFullyDisplayed, onComplete]);
|
||||
|
||||
const processedContent = useMemo(
|
||||
() => processContent(displayedContent),
|
||||
[displayedContent]
|
||||
);
|
||||
|
||||
// Stable-identity components for ReactMarkdown. Dynamic data (`state`,
|
||||
// `processedContent`) flows through refs so the callback identities
|
||||
// never change — otherwise every typewriter tick would invalidate
|
||||
// React reconciliation on the markdown subtree.
|
||||
const stateRef = useRef(state);
|
||||
stateRef.current = state;
|
||||
const processedContentRef = useRef(processedContent);
|
||||
processedContentRef.current = processedContent;
|
||||
|
||||
const markdownComponents = useMemo<Components>(
|
||||
() => ({
|
||||
a: ({ href, children }) => {
|
||||
const s = stateRef.current;
|
||||
const imageFileId = extractChatImageFileId(
|
||||
href,
|
||||
String(children ?? "")
|
||||
);
|
||||
if (imageFileId) {
|
||||
return (
|
||||
<InMessageImage
|
||||
fileId={imageFileId}
|
||||
fileName={String(children ?? "")}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<MemoizedAnchor
|
||||
updatePresentingDocument={s?.setPresentingDocument || (() => {})}
|
||||
docs={s?.docs || []}
|
||||
userFiles={s?.userFiles || []}
|
||||
citations={s?.citations}
|
||||
href={href}
|
||||
>
|
||||
{children}
|
||||
</MemoizedAnchor>
|
||||
);
|
||||
},
|
||||
p: ({ children }) => (
|
||||
<MemoizedParagraph className="font-main-content-body">
|
||||
{children}
|
||||
</MemoizedParagraph>
|
||||
),
|
||||
pre: ({ children }) => <>{children}</>,
|
||||
b: ({ className, children }) => (
|
||||
<span className={className}>{children}</span>
|
||||
),
|
||||
ul: ({ className, children, ...rest }) => (
|
||||
<ul className={className} {...rest}>
|
||||
{children}
|
||||
</ul>
|
||||
),
|
||||
ol: ({ className, children, ...rest }) => (
|
||||
<ol className={className} {...rest}>
|
||||
{children}
|
||||
</ol>
|
||||
),
|
||||
li: ({ className, children, ...rest }) => (
|
||||
<li className={className} {...rest}>
|
||||
{children}
|
||||
</li>
|
||||
),
|
||||
table: ({ className, children, ...rest }) => (
|
||||
<div className="markdown-table-breakout">
|
||||
<table className={cn(className, "min-w-full")} {...rest}>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
),
|
||||
code: ({ node, className, children }) => {
|
||||
const codeText = extractCodeText(
|
||||
node,
|
||||
processedContentRef.current,
|
||||
children
|
||||
);
|
||||
return (
|
||||
<CodeBlock className={className} codeText={codeText}>
|
||||
{children}
|
||||
</CodeBlock>
|
||||
);
|
||||
},
|
||||
}),
|
||||
[]
|
||||
);
|
||||
|
||||
const shouldShowThinkingPlaceholder =
|
||||
shouldUseAutoPlaybackSync &&
|
||||
isAwaitingAutoPlaybackStart &&
|
||||
@@ -375,16 +292,16 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
!stopPacketSeen;
|
||||
|
||||
const shouldShowCursor =
|
||||
displayedContent.length > 0 &&
|
||||
((isStreamingAnimationEnabled && !streamFullyDisplayed) ||
|
||||
(!isStreamingAnimationEnabled && !stopPacketSeen) ||
|
||||
content.length > 0 &&
|
||||
(!stopPacketSeen ||
|
||||
(shouldUseAutoPlaybackSync && content.length < fullContent.length));
|
||||
|
||||
// `[*]() ` is rendered by the anchor component as an inline blinking
|
||||
// caret, keeping it flush with the trailing character.
|
||||
const markdownInput = shouldShowCursor
|
||||
? processedContent + " [*]() "
|
||||
: processedContent;
|
||||
const { renderedContent } = useMarkdownRenderer(
|
||||
// the [*]() is a hack to show a blinking dot when the packet is not complete
|
||||
shouldShowCursor ? content + " [*]() " : content,
|
||||
state,
|
||||
"font-main-content-body"
|
||||
);
|
||||
|
||||
return children([
|
||||
{
|
||||
@@ -395,26 +312,8 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
<Text as="span" secondaryBody text04 className="italic">
|
||||
Thinking
|
||||
</Text>
|
||||
) : displayedContent.length > 0 ? (
|
||||
<div dir="auto">
|
||||
<ReactMarkdown
|
||||
className="prose prose-onyx font-main-content-body max-w-full"
|
||||
components={markdownComponents}
|
||||
remarkPlugins={
|
||||
streamFullyDisplayed
|
||||
? FULL_REMARK_PLUGINS
|
||||
: STREAMING_REMARK_PLUGINS
|
||||
}
|
||||
rehypePlugins={
|
||||
streamFullyDisplayed
|
||||
? FULL_REHYPE_PLUGINS
|
||||
: STREAMING_REHYPE_PLUGINS
|
||||
}
|
||||
urlTransform={transformLinkUri}
|
||||
>
|
||||
{markdownInput}
|
||||
</ReactMarkdown>
|
||||
</div>
|
||||
) : content.length > 0 ? (
|
||||
<>{renderedContent}</>
|
||||
) : (
|
||||
<BlinkingBar addMargin />
|
||||
),
|
||||
|
||||
@@ -18,7 +18,7 @@ import {
|
||||
isRecommendedModel,
|
||||
} from "@/app/craft/onboarding/constants";
|
||||
import { ToggleWarningModal } from "./ToggleWarningModal";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import {
|
||||
Accordion,
|
||||
|
||||
@@ -48,7 +48,7 @@ import NotAllowedModal from "@/app/craft/onboarding/components/NotAllowedModal";
|
||||
import { useOnboarding } from "@/app/craft/onboarding/BuildOnboardingProvider";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import {
|
||||
getBuildUserPersona,
|
||||
getPersonaInfo,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
:root {
|
||||
--app-page-main-content-width: 45rem;
|
||||
--app-page-main-content-width: 52.5rem;
|
||||
--block-width-form-input-min: 10rem;
|
||||
|
||||
--container-sm: 42rem;
|
||||
|
||||
@@ -45,9 +45,6 @@ import { personaIncludesRetrieval } from "@/app/app/services/lib";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { eeGated } from "@/ce";
|
||||
import EESearchUI from "@/ee/sections/SearchUI";
|
||||
import useMultiModelChat from "@/hooks/useMultiModelChat";
|
||||
import ModelSelector from "@/refresh-components/popovers/ModelSelector";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
const SearchUI = eeGated(EESearchUI);
|
||||
|
||||
@@ -108,20 +105,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
// If no LLM provider is configured (e.g., fresh signup), the input bar is
|
||||
// disabled and a "Set up an LLM" button is shown (see bottom of component).
|
||||
const llmManager = useLlmManager(undefined, liveAgent ?? undefined);
|
||||
const multiModel = useMultiModelChat(llmManager);
|
||||
|
||||
// Sync single-model selection to llmManager so the submission path
|
||||
// uses the correct provider/version (mirrors AppPage behaviour).
|
||||
useEffect(() => {
|
||||
if (multiModel.selectedModels.length === 1) {
|
||||
const model = multiModel.selectedModels[0]!;
|
||||
llmManager.updateCurrentLlm({
|
||||
name: model.name,
|
||||
provider: model.provider,
|
||||
modelName: model.modelName,
|
||||
});
|
||||
}
|
||||
}, [multiModel.selectedModels]);
|
||||
|
||||
// Deep research toggle
|
||||
const { deepResearchEnabled, toggleDeepResearch } = useDeepResearchToggle({
|
||||
@@ -312,17 +295,12 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
|
||||
// If we already have messages (chat session started), always use chat mode
|
||||
// (matches AppPage behavior where existing sessions bypass classification)
|
||||
const selectedModels = multiModel.isMultiModelActive
|
||||
? multiModel.selectedModels
|
||||
: undefined;
|
||||
|
||||
if (hasMessages) {
|
||||
onSubmit({
|
||||
message: submittedMessage,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
|
||||
deepResearch: deepResearchEnabled,
|
||||
additionalContext,
|
||||
selectedModels,
|
||||
});
|
||||
return;
|
||||
}
|
||||
@@ -332,9 +310,8 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
onSubmit({
|
||||
message: chatMessage,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
|
||||
deepResearch: deepResearchEnabled,
|
||||
additionalContext,
|
||||
selectedModels,
|
||||
});
|
||||
};
|
||||
|
||||
@@ -351,8 +328,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
submitQuery,
|
||||
tabReadingEnabled,
|
||||
currentTabUrl,
|
||||
multiModel.isMultiModelActive,
|
||||
multiModel.selectedModels,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -370,16 +345,10 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
onSubmit({
|
||||
message: lastUserMsg.message,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
|
||||
deepResearch: deepResearchEnabled,
|
||||
messageIdToResend: lastUserMsg.messageId,
|
||||
});
|
||||
}, [
|
||||
messageHistory,
|
||||
onSubmit,
|
||||
currentMessageFiles,
|
||||
deepResearchEnabled,
|
||||
multiModel.isMultiModelActive,
|
||||
]);
|
||||
}, [messageHistory, onSubmit, currentMessageFiles, deepResearchEnabled]);
|
||||
|
||||
// Start a new chat session in the side panel
|
||||
const handleNewChat = useCallback(() => {
|
||||
@@ -487,7 +456,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
onResubmit={handleResubmitLastMessage}
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
anchorNodeId={anchorNodeId}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
/>
|
||||
</ChatScrollContainer>
|
||||
</>
|
||||
@@ -496,23 +464,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
{/* Welcome message - centered when no messages and not in search mode */}
|
||||
{!hasMessages && !isSearch && (
|
||||
<div className="relative w-full flex-1 flex flex-col items-center justify-end">
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="between"
|
||||
alignItems="end"
|
||||
className="max-w-[var(--app-page-main-content-width)]"
|
||||
>
|
||||
<WelcomeMessage isDefaultAgent />
|
||||
{liveAgent && !llmManager.isLoadingProviders && (
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
)}
|
||||
</Section>
|
||||
<WelcomeMessage isDefaultAgent />
|
||||
<Spacer rem={1.5} />
|
||||
</div>
|
||||
)}
|
||||
@@ -522,25 +474,14 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
ref={inputRef}
|
||||
className={cn(
|
||||
"w-full flex flex-col",
|
||||
!isSidePanel && "max-w-[var(--app-page-main-content-width)]"
|
||||
!isSidePanel &&
|
||||
"max-w-[var(--app-page-main-content-width)] px-4"
|
||||
)}
|
||||
>
|
||||
{hasMessages && liveAgent && !llmManager.isLoadingProviders && (
|
||||
<div className="pb-1">
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<AppInputBar
|
||||
ref={chatInputBarRef}
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
toggleDeepResearch={toggleDeepResearch}
|
||||
isMultiModelActive={multiModel.isMultiModelActive}
|
||||
filterManager={filterManager}
|
||||
llmManager={llmManager}
|
||||
initialMessage={message}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import { useMemo } from "react";
|
||||
import { parseLlmDescriptor, structureValue } from "@/lib/llmConfig/utils";
|
||||
import { DefaultModel, LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { createIcon } from "@/components/icons/icons";
|
||||
|
||||
|
||||
@@ -644,7 +644,6 @@ export default function useChatController({
|
||||
});
|
||||
node.modelDisplayName = model.displayName;
|
||||
node.overridden_model = model.modelName;
|
||||
node.is_generating = true;
|
||||
return node;
|
||||
});
|
||||
}
|
||||
@@ -712,13 +711,6 @@ export default function useChatController({
|
||||
? selectedModels?.map((m) => m.displayName) ?? []
|
||||
: [];
|
||||
|
||||
// rAF-batched flush state. One Zustand write per frame instead of
|
||||
// one per packet.
|
||||
const dirtyModelIndices = new Set<number>();
|
||||
let singleModelDirty = false;
|
||||
let userNodeDirty = false;
|
||||
let pendingFlush = false;
|
||||
|
||||
/** Build a non-errored multi-model assistant node for upsert. */
|
||||
function buildAssistantNodeUpdate(
|
||||
idx: number,
|
||||
@@ -748,124 +740,16 @@ export default function useChatController({
|
||||
};
|
||||
}
|
||||
|
||||
/** With `onlyDirty`, rebuilds only those model nodes — unchanged
|
||||
* siblings keep their stable Message ref so React memo short-circuits. */
|
||||
function buildNonErroredNodes(
|
||||
overrides?: Partial<Message>,
|
||||
onlyDirty?: Set<number> | null
|
||||
): Message[] {
|
||||
/** Build updated nodes for all non-errored models. */
|
||||
function buildNonErroredNodes(overrides?: Partial<Message>): Message[] {
|
||||
const nodes: Message[] = [];
|
||||
for (let idx = 0; idx < initialAssistantNodes.length; idx++) {
|
||||
if (erroredModelIndices.has(idx)) continue;
|
||||
if (onlyDirty && !onlyDirty.has(idx)) continue;
|
||||
nodes.push(buildAssistantNodeUpdate(idx, overrides));
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
|
||||
/** Flush accumulated packet state into the tree as one Zustand
|
||||
* update. No-op when nothing is pending. */
|
||||
function flushPendingUpdates() {
|
||||
if (!pendingFlush) return;
|
||||
pendingFlush = false;
|
||||
|
||||
parentMessage =
|
||||
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
|
||||
|
||||
let messagesToUpsert: Message[];
|
||||
|
||||
if (isMultiModel) {
|
||||
if (dirtyModelIndices.size === 0 && !userNodeDirty) return;
|
||||
|
||||
const dirtySnapshot = new Set(dirtyModelIndices);
|
||||
dirtyModelIndices.clear();
|
||||
const dirtyNodes = buildNonErroredNodes(undefined, dirtySnapshot);
|
||||
|
||||
if (userNodeDirty) {
|
||||
userNodeDirty = false;
|
||||
// Read current user node to preserve childrenNodeIds
|
||||
// (initialUserNode's are stale from creation time).
|
||||
const currentUserNode =
|
||||
currentMessageTreeLocal.get(initialUserNode.nodeId) ||
|
||||
initialUserNode;
|
||||
const updatedUserNode: Message = {
|
||||
...currentUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
};
|
||||
messagesToUpsert = [updatedUserNode, ...dirtyNodes];
|
||||
} else {
|
||||
messagesToUpsert = dirtyNodes;
|
||||
}
|
||||
|
||||
if (messagesToUpsert.length === 0) return;
|
||||
} else {
|
||||
if (!singleModelDirty) return;
|
||||
singleModelDirty = false;
|
||||
|
||||
messagesToUpsert = [
|
||||
{
|
||||
...initialUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
},
|
||||
{
|
||||
...initialAgentNode,
|
||||
messageId: newAgentMessageId ?? undefined,
|
||||
message: error || answer,
|
||||
type: error ? "error" : "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents: documents,
|
||||
citations: finalMessage?.citations || citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCall: finalMessage?.tool_call || toolCall,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
packets: packets,
|
||||
packetCount: packets.length,
|
||||
processingDurationSeconds:
|
||||
finalMessage?.processing_duration_seconds ??
|
||||
(() => {
|
||||
const startTime = useChatSessionStore
|
||||
.getState()
|
||||
.getStreamingStartTime(frozenSessionId);
|
||||
return startTime
|
||||
? Math.floor((Date.now() - startTime) / 1000)
|
||||
: undefined;
|
||||
})(),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: messagesToUpsert,
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
}
|
||||
|
||||
/** Awaits next animation frame (or a setTimeout fallback when the
|
||||
* tab is hidden — rAF is paused in background tabs, which would
|
||||
* otherwise hang the stream loop here), then flushes. Aligns
|
||||
* React updates with the paint cycle when visible. */
|
||||
function flushViaRAF(): Promise<void> {
|
||||
return new Promise<void>((resolve) => {
|
||||
let done = false;
|
||||
const flush = () => {
|
||||
if (done) return;
|
||||
done = true;
|
||||
flushPendingUpdates();
|
||||
resolve();
|
||||
};
|
||||
requestAnimationFrame(flush);
|
||||
// Fallback for hidden tabs where rAF is paused. Throttled to
|
||||
// ~1s by browsers, matching the previous setTimeout(500) cadence.
|
||||
setTimeout(flush, 100);
|
||||
});
|
||||
}
|
||||
|
||||
let streamSucceeded = false;
|
||||
|
||||
try {
|
||||
@@ -952,12 +836,7 @@ export default function useChatController({
|
||||
await delay(50);
|
||||
while (!stack.isComplete || !stack.isEmpty()) {
|
||||
if (stack.isEmpty()) {
|
||||
// Flush the burst on the next paint, or idle briefly.
|
||||
if (pendingFlush) {
|
||||
await flushViaRAF();
|
||||
} else {
|
||||
await delay(0.5);
|
||||
}
|
||||
await delay(0.5);
|
||||
}
|
||||
|
||||
if (!stack.isEmpty() && !controller.signal.aborted) {
|
||||
@@ -981,7 +860,6 @@ export default function useChatController({
|
||||
if ((packet as MessageResponseIDInfo).user_message_id) {
|
||||
newUserMessageId = (packet as MessageResponseIDInfo)
|
||||
.user_message_id;
|
||||
userNodeDirty = true;
|
||||
|
||||
// Track extension queries in PostHog (reuses isExtension/extensionContext from above)
|
||||
if (isExtension) {
|
||||
@@ -1020,8 +898,6 @@ export default function useChatController({
|
||||
modelDisplayNames[mi] = slot.model_name;
|
||||
}
|
||||
}
|
||||
userNodeDirty = true;
|
||||
pendingFlush = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1033,7 +909,6 @@ export default function useChatController({
|
||||
!files.some((existingFile) => existingFile.id === newFile.id)
|
||||
);
|
||||
files = files.concat(newUserFiles);
|
||||
if (newUserFiles.length > 0) userNodeDirty = true;
|
||||
}
|
||||
|
||||
if (Object.hasOwn(packet, "file_ids")) {
|
||||
@@ -1053,20 +928,15 @@ export default function useChatController({
|
||||
|
||||
// In multi-model mode, route per-model errors to the specific model's
|
||||
// node instead of killing the entire stream. Other models keep streaming.
|
||||
if (isMultiModel) {
|
||||
// Multi-model: isolate the error to its panel. Never throw
|
||||
// or set global error state — other models keep streaming.
|
||||
const errorModelIndex = streamingError.details?.model_index as
|
||||
| number
|
||||
| undefined;
|
||||
if (isMultiModel && streamingError.details?.model_index != null) {
|
||||
const errorModelIndex = streamingError.details
|
||||
.model_index as number;
|
||||
if (
|
||||
errorModelIndex != null &&
|
||||
errorModelIndex >= 0 &&
|
||||
errorModelIndex < initialAssistantNodes.length
|
||||
) {
|
||||
const errorNode = initialAssistantNodes[errorModelIndex]!;
|
||||
erroredModelIndices.add(errorModelIndex);
|
||||
dirtyModelIndices.delete(errorModelIndex);
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: [
|
||||
{
|
||||
@@ -1093,15 +963,8 @@ export default function useChatController({
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
} else {
|
||||
// Error without model_index in multi-model — can't route
|
||||
// to a specific panel. Log and continue; the stream loop
|
||||
// stays alive for other models.
|
||||
console.warn(
|
||||
"Multi-model error without model_index:",
|
||||
streamingError.error
|
||||
);
|
||||
}
|
||||
// Skip the normal per-packet upsert — we already upserted the error node
|
||||
continue;
|
||||
} else {
|
||||
// Single-model: kill the stream
|
||||
@@ -1130,21 +993,19 @@ export default function useChatController({
|
||||
|
||||
if (isMultiModel) {
|
||||
// Multi-model: route packet by placement.model_index.
|
||||
// OverallStop (type "stop") has model_index=null — it's a
|
||||
// global terminal packet that must be delivered to ALL
|
||||
// models so each panel's AgentMessage sees the stop and
|
||||
// exits "Thinking..." state.
|
||||
// OverallStop (type "stop") has model_index=null — it's a global
|
||||
// terminal packet that must be delivered to ALL models so each
|
||||
// panel's AgentMessage sees the stop and exits "Thinking..." state.
|
||||
const isGlobalStop =
|
||||
packetObj.type === "stop" &&
|
||||
typedPacket.placement?.model_index == null;
|
||||
|
||||
if (isGlobalStop) {
|
||||
for (let mi = 0; mi < packetsPerModel.length; mi++) {
|
||||
// Mutated in place — change detection uses packetCount, not array identity.
|
||||
packetsPerModel[mi]!.push(typedPacket);
|
||||
if (!erroredModelIndices.has(mi)) {
|
||||
dirtyModelIndices.add(mi);
|
||||
}
|
||||
packetsPerModel[mi] = [
|
||||
...packetsPerModel[mi]!,
|
||||
typedPacket,
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1154,10 +1015,10 @@ export default function useChatController({
|
||||
modelIndex >= 0 &&
|
||||
modelIndex < packetsPerModel.length
|
||||
) {
|
||||
packetsPerModel[modelIndex]!.push(typedPacket);
|
||||
if (!erroredModelIndices.has(modelIndex)) {
|
||||
dirtyModelIndices.add(modelIndex);
|
||||
}
|
||||
packetsPerModel[modelIndex] = [
|
||||
...packetsPerModel[modelIndex]!,
|
||||
typedPacket,
|
||||
];
|
||||
|
||||
if (packetObj.type === "citation_info") {
|
||||
const citationInfo = packetObj as {
|
||||
@@ -1187,7 +1048,6 @@ export default function useChatController({
|
||||
// Single-model
|
||||
packets.push(typedPacket);
|
||||
packetsVersion++;
|
||||
singleModelDirty = true;
|
||||
|
||||
if (packetObj.type === "citation_info") {
|
||||
const citationInfo = packetObj as {
|
||||
@@ -1214,16 +1074,73 @@ export default function useChatController({
|
||||
console.warn("Unknown packet:", JSON.stringify(packet));
|
||||
}
|
||||
|
||||
// Mark dirty — flushViaRAF coalesces bursts into one React update per frame.
|
||||
if (!isMultiModel) singleModelDirty = true;
|
||||
pendingFlush = true;
|
||||
// on initial message send, we insert a dummy system message
|
||||
// set this as the parent here if no parent is set
|
||||
parentMessage =
|
||||
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
|
||||
|
||||
// Build the messages to upsert based on single vs multi-model mode
|
||||
let messagesToUpsertInLoop: Message[];
|
||||
|
||||
if (isMultiModel) {
|
||||
// Read the current user node from the tree to preserve childrenNodeIds
|
||||
// (initialUserNode has stale/empty children from creation time).
|
||||
const currentUserNode =
|
||||
currentMessageTreeLocal.get(initialUserNode.nodeId) ||
|
||||
initialUserNode;
|
||||
const updatedUserNode: Message = {
|
||||
...currentUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
};
|
||||
messagesToUpsertInLoop = [
|
||||
updatedUserNode,
|
||||
...buildNonErroredNodes(),
|
||||
];
|
||||
} else {
|
||||
messagesToUpsertInLoop = [
|
||||
{
|
||||
...initialUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
},
|
||||
{
|
||||
...initialAgentNode,
|
||||
messageId: newAgentMessageId ?? undefined,
|
||||
message: error || answer,
|
||||
type: error ? "error" : "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents: documents,
|
||||
citations: finalMessage?.citations || citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCall: finalMessage?.tool_call || toolCall,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
packets: packets,
|
||||
packetCount: packets.length,
|
||||
processingDurationSeconds:
|
||||
finalMessage?.processing_duration_seconds ??
|
||||
(() => {
|
||||
const startTime = useChatSessionStore
|
||||
.getState()
|
||||
.getStreamingStartTime(frozenSessionId);
|
||||
return startTime
|
||||
? Math.floor((Date.now() - startTime) / 1000)
|
||||
: undefined;
|
||||
})(),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: messagesToUpsertInLoop,
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
}
|
||||
}
|
||||
// Flush any tail state from the final packet(s) before declaring
|
||||
// the stream complete. Without this, the last ≤1 frame of packets
|
||||
// could get stranded in local state.
|
||||
flushPendingUpdates();
|
||||
|
||||
// Surface FIFO errors (e.g. 429 before any packets arrive) so the
|
||||
// catch block replaces the thinking placeholder with an error message.
|
||||
if (stack.error) {
|
||||
@@ -1257,7 +1174,6 @@ export default function useChatController({
|
||||
errorCode,
|
||||
isRetryable,
|
||||
errorDetails,
|
||||
is_generating: false,
|
||||
})
|
||||
: [
|
||||
{
|
||||
|
||||
@@ -48,7 +48,6 @@ describe("useSettings", () => {
|
||||
anonymous_user_enabled: false,
|
||||
invite_only_enabled: false,
|
||||
deep_research_enabled: true,
|
||||
multi_model_chat_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
});
|
||||
@@ -66,7 +65,6 @@ describe("useSettings", () => {
|
||||
anonymous_user_enabled: false,
|
||||
invite_only_enabled: false,
|
||||
deep_research_enabled: true,
|
||||
multi_model_chat_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
};
|
||||
|
||||
@@ -23,7 +23,6 @@ const DEFAULT_SETTINGS = {
|
||||
anonymous_user_enabled: false,
|
||||
invite_only_enabled: false,
|
||||
deep_research_enabled: true,
|
||||
multi_model_chat_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
} satisfies Settings;
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
|
||||
// Fixed reveal rate — NOT adaptive. Any ceil(delta/N) formula produces
|
||||
// visible chunks on burst packet arrivals. 1 = 60 cps, 2 = 120 cps.
|
||||
const CHARS_PER_FRAME = 2;
|
||||
|
||||
/**
|
||||
* Reveals `target` one character at a time on each animation frame.
|
||||
* When `enabled` is false (historical messages), snaps to full on mount.
|
||||
* The rAF loop pauses once caught up and resumes when `target` grows.
|
||||
*/
|
||||
export function useTypewriter(target: string, enabled: boolean): string {
|
||||
// Ref so the rAF loop reads latest length without restarting.
|
||||
const targetRef = useRef(target);
|
||||
targetRef.current = target;
|
||||
|
||||
// Mirror `enabled` so the restart effect can short-circuit when the
|
||||
// caller has turned animation off (e.g. voice-mode, where display is
|
||||
// driven by audio position — the typewriter must stay idle and not
|
||||
// animate a jump after audio ends).
|
||||
const enabledRef = useRef(enabled);
|
||||
enabledRef.current = enabled;
|
||||
|
||||
// `enabled` controls initial state: animate from 0 vs snap to full for
|
||||
// history/voice. Transitions mid-stream are handled via enabledRef in
|
||||
// the restart effect so a flip to false doesn't dump the buffered tail
|
||||
// *and* doesn't spin up the rAF loop on later growth.
|
||||
const [displayedLength, setDisplayedLength] = useState<number>(
|
||||
enabled ? 0 : target.length
|
||||
);
|
||||
|
||||
// Mirror displayedLength in a ref so the rAF loop can read the latest
|
||||
// value without stale-closure issues AND without needing a functional
|
||||
// state updater (which must be pure — no ref mutations inside).
|
||||
const displayedLengthRef = useRef(displayedLength);
|
||||
|
||||
// Clamp (not reset) on target shrink — preserves already-revealed chars
|
||||
// across user-cancel freeze and regeneration.
|
||||
const prevTargetLengthRef = useRef(target.length);
|
||||
useEffect(() => {
|
||||
if (target.length < prevTargetLengthRef.current) {
|
||||
const clamped = Math.min(displayedLengthRef.current, target.length);
|
||||
displayedLengthRef.current = clamped;
|
||||
setDisplayedLength(clamped);
|
||||
}
|
||||
prevTargetLengthRef.current = target.length;
|
||||
}, [target.length]);
|
||||
|
||||
// Self-scheduling rAF loop. Pauses when caught up so idle/historical
|
||||
// messages don't run a 60fps no-op updater for their entire lifetime.
|
||||
const rafIdRef = useRef<number | null>(null);
|
||||
const runningRef = useRef(false);
|
||||
const startLoopRef = useRef<(() => void) | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const tick = () => {
|
||||
const targetLen = targetRef.current.length;
|
||||
const prev = displayedLengthRef.current;
|
||||
if (prev >= targetLen) {
|
||||
// Caught up — pause the loop. The sibling effect below will
|
||||
// restart it when `target` grows.
|
||||
runningRef.current = false;
|
||||
rafIdRef.current = null;
|
||||
return;
|
||||
}
|
||||
const next = Math.min(prev + CHARS_PER_FRAME, targetLen);
|
||||
displayedLengthRef.current = next;
|
||||
setDisplayedLength(next);
|
||||
rafIdRef.current = requestAnimationFrame(tick);
|
||||
};
|
||||
|
||||
const start = () => {
|
||||
if (runningRef.current) return;
|
||||
// Animation disabled — snap to full and stay idle. This is the
|
||||
// voice-mode path where content is driven by audio position, and
|
||||
// any "gap" (e.g. user stops audio early) must jump instantly
|
||||
// instead of animating a 1500-char typewriter burst.
|
||||
if (!enabledRef.current) {
|
||||
const targetLen = targetRef.current.length;
|
||||
if (displayedLengthRef.current !== targetLen) {
|
||||
displayedLengthRef.current = targetLen;
|
||||
setDisplayedLength(targetLen);
|
||||
}
|
||||
return;
|
||||
}
|
||||
runningRef.current = true;
|
||||
rafIdRef.current = requestAnimationFrame(tick);
|
||||
};
|
||||
|
||||
startLoopRef.current = start;
|
||||
|
||||
if (targetRef.current.length > displayedLengthRef.current) {
|
||||
start();
|
||||
}
|
||||
|
||||
return () => {
|
||||
runningRef.current = false;
|
||||
if (rafIdRef.current !== null) {
|
||||
cancelAnimationFrame(rafIdRef.current);
|
||||
rafIdRef.current = null;
|
||||
}
|
||||
startLoopRef.current = null;
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Restart the loop when target grows past what's currently displayed.
|
||||
useEffect(() => {
|
||||
if (target.length > displayedLength && startLoopRef.current) {
|
||||
startLoopRef.current();
|
||||
}
|
||||
}, [target.length, displayedLength]);
|
||||
|
||||
return useMemo(
|
||||
() => target.slice(0, Math.min(displayedLength, target.length)),
|
||||
[target, displayedLength]
|
||||
);
|
||||
}
|
||||
@@ -27,7 +27,6 @@ export interface Settings {
|
||||
query_history_type: QueryHistoryType;
|
||||
|
||||
deep_research_enabled?: boolean;
|
||||
multi_model_chat_enabled?: boolean;
|
||||
search_ui_enabled?: boolean;
|
||||
|
||||
// Image processing settings
|
||||
|
||||
@@ -9,7 +9,6 @@ import { useField, useFormikContext } from "formik";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { Content } from "@opal/layouts";
|
||||
import Label from "@/refresh-components/form/Label";
|
||||
import type { TagProps } from "@opal/components/tag/components";
|
||||
|
||||
interface OrientationLayoutProps {
|
||||
name?: string;
|
||||
@@ -17,8 +16,6 @@ interface OrientationLayoutProps {
|
||||
nonInteractive?: boolean;
|
||||
children?: React.ReactNode;
|
||||
title: string | RichStr;
|
||||
/** Tag rendered inline beside the title (passed through to Content). */
|
||||
tag?: TagProps;
|
||||
description?: string | RichStr;
|
||||
suffix?: "optional" | (string & {});
|
||||
sizePreset?: "main-content" | "main-ui";
|
||||
@@ -131,7 +128,6 @@ function HorizontalInputLayout({
|
||||
children,
|
||||
center,
|
||||
title,
|
||||
tag,
|
||||
description,
|
||||
suffix,
|
||||
sizePreset = "main-content",
|
||||
@@ -148,7 +144,6 @@ function HorizontalInputLayout({
|
||||
title={title}
|
||||
description={description}
|
||||
suffix={suffix}
|
||||
tag={tag}
|
||||
sizePreset={sizePreset}
|
||||
variant="section"
|
||||
widthVariant="full"
|
||||
|
||||
@@ -694,25 +694,6 @@ export function useLlmManager(
|
||||
prevAgentIdRef.current = liveAgent?.id;
|
||||
}, [liveAgent?.id]);
|
||||
|
||||
// Clear manual override when arriving at a *different* existing session
|
||||
// from any previously-seen defined session. Tracks only the last
|
||||
// *defined* session id so a round-trip through new-chat (A → undefined
|
||||
// → B) still resets, while A → undefined (new-chat) preserves it.
|
||||
const prevDefinedSessionIdRef = useRef<string | undefined>(undefined);
|
||||
useEffect(() => {
|
||||
const nextId = currentChatSession?.id;
|
||||
if (
|
||||
nextId !== undefined &&
|
||||
prevDefinedSessionIdRef.current !== undefined &&
|
||||
nextId !== prevDefinedSessionIdRef.current
|
||||
) {
|
||||
setUserHasManuallyOverriddenLLM(false);
|
||||
}
|
||||
if (nextId !== undefined) {
|
||||
prevDefinedSessionIdRef.current = nextId;
|
||||
}
|
||||
}, [currentChatSession?.id]);
|
||||
|
||||
function getValidLlmDescriptor(
|
||||
modelName: string | null | undefined
|
||||
): LlmDescriptor {
|
||||
@@ -734,9 +715,8 @@ export function useLlmManager(
|
||||
|
||||
if (llmProviders === undefined || llmProviders === null) {
|
||||
resolved = manualLlm;
|
||||
} else if (userHasManuallyOverriddenLLM) {
|
||||
// Manual override wins over session's `current_alternate_model`.
|
||||
// Cleared on cross-session navigation by the effect above.
|
||||
} else if (userHasManuallyOverriddenLLM && !currentChatSession) {
|
||||
// User has overridden in this session and switched to a new session
|
||||
resolved = manualLlm;
|
||||
} else if (currentChatSession?.current_alternate_model) {
|
||||
resolved = getValidLlmDescriptorForProviders(
|
||||
@@ -748,6 +728,8 @@ export function useLlmManager(
|
||||
liveAgent.llm_model_version_override,
|
||||
llmProviders
|
||||
);
|
||||
} else if (userHasManuallyOverriddenLLM) {
|
||||
resolved = manualLlm;
|
||||
} else if (user?.preferences?.default_model) {
|
||||
resolved = getValidLlmDescriptorForProviders(
|
||||
user.preferences.default_model,
|
||||
|
||||
@@ -1,251 +0,0 @@
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { SvgCpu, SvgPlug, SvgServer } from "@opal/icons";
|
||||
import {
|
||||
SvgBifrost,
|
||||
SvgOpenai,
|
||||
SvgClaude,
|
||||
SvgOllama,
|
||||
SvgAws,
|
||||
SvgOpenrouter,
|
||||
SvgAzure,
|
||||
SvgGemini,
|
||||
SvgLitellm,
|
||||
SvgLmStudio,
|
||||
SvgMicrosoft,
|
||||
SvgMistral,
|
||||
SvgDeepseek,
|
||||
SvgQwen,
|
||||
SvgGoogle,
|
||||
} from "@opal/logos";
|
||||
import { ZAIIcon } from "@/components/icons/icons";
|
||||
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
|
||||
import type { LLMProviderView } from "@/interfaces/llm";
|
||||
import OpenAIModal from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
import AnthropicModal from "@/sections/modals/llmConfig/AnthropicModal";
|
||||
import OllamaModal from "@/sections/modals/llmConfig/OllamaModal";
|
||||
import AzureModal from "@/sections/modals/llmConfig/AzureModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
|
||||
// ─── Text (LLM) providers ────────────────────────────────────────────────────
|
||||
|
||||
export interface ProviderEntry {
|
||||
icon: IconFunctionComponent;
|
||||
productName: string;
|
||||
companyName: string;
|
||||
Modal: React.ComponentType<LLMProviderFormProps>;
|
||||
}
|
||||
|
||||
const PROVIDERS: Record<string, ProviderEntry> = {
|
||||
[LLMProviderName.OPENAI]: {
|
||||
icon: SvgOpenai,
|
||||
productName: "GPT",
|
||||
companyName: "OpenAI",
|
||||
Modal: OpenAIModal,
|
||||
},
|
||||
[LLMProviderName.ANTHROPIC]: {
|
||||
icon: SvgClaude,
|
||||
productName: "Claude",
|
||||
companyName: "Anthropic",
|
||||
Modal: AnthropicModal,
|
||||
},
|
||||
[LLMProviderName.VERTEX_AI]: {
|
||||
icon: SvgGemini,
|
||||
productName: "Gemini",
|
||||
companyName: "Google Cloud Vertex AI",
|
||||
Modal: VertexAIModal,
|
||||
},
|
||||
[LLMProviderName.BEDROCK]: {
|
||||
icon: SvgAws,
|
||||
productName: "Amazon Bedrock",
|
||||
companyName: "AWS",
|
||||
Modal: BedrockModal,
|
||||
},
|
||||
[LLMProviderName.AZURE]: {
|
||||
icon: SvgAzure,
|
||||
productName: "Azure OpenAI",
|
||||
companyName: "Microsoft Azure",
|
||||
Modal: AzureModal,
|
||||
},
|
||||
[LLMProviderName.LITELLM]: {
|
||||
icon: SvgLitellm,
|
||||
productName: "LiteLLM",
|
||||
companyName: "LiteLLM",
|
||||
Modal: CustomModal,
|
||||
},
|
||||
[LLMProviderName.LITELLM_PROXY]: {
|
||||
icon: SvgLitellm,
|
||||
productName: "LiteLLM Proxy",
|
||||
companyName: "LiteLLM Proxy",
|
||||
Modal: LiteLLMProxyModal,
|
||||
},
|
||||
[LLMProviderName.OLLAMA_CHAT]: {
|
||||
icon: SvgOllama,
|
||||
productName: "Ollama",
|
||||
companyName: "Ollama",
|
||||
Modal: OllamaModal,
|
||||
},
|
||||
[LLMProviderName.OPENROUTER]: {
|
||||
icon: SvgOpenrouter,
|
||||
productName: "OpenRouter",
|
||||
companyName: "OpenRouter",
|
||||
Modal: OpenRouterModal,
|
||||
},
|
||||
[LLMProviderName.LM_STUDIO]: {
|
||||
icon: SvgLmStudio,
|
||||
productName: "LM Studio",
|
||||
companyName: "LM Studio",
|
||||
Modal: LMStudioModal,
|
||||
},
|
||||
[LLMProviderName.BIFROST]: {
|
||||
icon: SvgBifrost,
|
||||
productName: "Bifrost",
|
||||
companyName: "Bifrost",
|
||||
Modal: BifrostModal,
|
||||
},
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: {
|
||||
icon: SvgPlug,
|
||||
productName: "OpenAI-Compatible",
|
||||
companyName: "OpenAI-Compatible",
|
||||
Modal: OpenAICompatibleModal,
|
||||
},
|
||||
[LLMProviderName.CUSTOM]: {
|
||||
icon: SvgServer,
|
||||
productName: "Custom Models",
|
||||
companyName: "models from other LiteLLM-compatible providers",
|
||||
Modal: CustomModal,
|
||||
},
|
||||
};
|
||||
|
||||
const DEFAULT_ENTRY: ProviderEntry = {
|
||||
icon: SvgCpu,
|
||||
productName: "",
|
||||
companyName: "",
|
||||
Modal: CustomModal,
|
||||
};
|
||||
|
||||
// Providers that don't use custom_config themselves — if custom_config is
|
||||
// present it means the provider was originally created via CustomModal.
|
||||
const CUSTOM_CONFIG_OVERRIDES = new Set<string>([
|
||||
LLMProviderName.OPENAI,
|
||||
LLMProviderName.ANTHROPIC,
|
||||
LLMProviderName.AZURE,
|
||||
LLMProviderName.OPENROUTER,
|
||||
]);
|
||||
|
||||
export function getProvider(
|
||||
providerName: string,
|
||||
existingProvider?: LLMProviderView
|
||||
): ProviderEntry {
|
||||
const entry = PROVIDERS[providerName] ?? {
|
||||
...DEFAULT_ENTRY,
|
||||
productName: providerName,
|
||||
companyName: providerName,
|
||||
};
|
||||
|
||||
if (
|
||||
existingProvider?.custom_config != null &&
|
||||
CUSTOM_CONFIG_OVERRIDES.has(providerName)
|
||||
) {
|
||||
return { ...entry, Modal: CustomModal };
|
||||
}
|
||||
|
||||
return entry;
|
||||
}
|
||||
|
||||
// ─── Aggregator providers ────────────────────────────────────────────────────
|
||||
// Providers that host models from multiple vendors (e.g. Bedrock hosts Claude,
|
||||
// Llama, etc.) Used by the model-icon resolver to prioritise vendor icons.
|
||||
|
||||
export const AGGREGATOR_PROVIDERS = new Set([
|
||||
LLMProviderName.BEDROCK,
|
||||
"bedrock_converse",
|
||||
LLMProviderName.OPENROUTER,
|
||||
LLMProviderName.OLLAMA_CHAT,
|
||||
LLMProviderName.LM_STUDIO,
|
||||
LLMProviderName.LITELLM_PROXY,
|
||||
LLMProviderName.BIFROST,
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
LLMProviderName.VERTEX_AI,
|
||||
]);
|
||||
|
||||
// ─── Model-aware icon resolver ───────────────────────────────────────────────
|
||||
|
||||
const MODEL_ICON_MAP: Record<string, IconFunctionComponent> = {
|
||||
[LLMProviderName.OPENAI]: SvgOpenai,
|
||||
[LLMProviderName.ANTHROPIC]: SvgClaude,
|
||||
[LLMProviderName.OLLAMA_CHAT]: SvgOllama,
|
||||
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
|
||||
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
|
||||
[LLMProviderName.VERTEX_AI]: SvgGemini,
|
||||
[LLMProviderName.BEDROCK]: SvgAws,
|
||||
[LLMProviderName.LITELLM_PROXY]: SvgLitellm,
|
||||
[LLMProviderName.BIFROST]: SvgBifrost,
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: SvgPlug,
|
||||
|
||||
amazon: SvgAws,
|
||||
phi: SvgMicrosoft,
|
||||
mistral: SvgMistral,
|
||||
ministral: SvgMistral,
|
||||
llama: SvgCpu,
|
||||
ollama: SvgOllama,
|
||||
gemini: SvgGemini,
|
||||
deepseek: SvgDeepseek,
|
||||
claude: SvgClaude,
|
||||
azure: SvgAzure,
|
||||
microsoft: SvgMicrosoft,
|
||||
meta: SvgCpu,
|
||||
google: SvgGoogle,
|
||||
qwen: SvgQwen,
|
||||
qwq: SvgQwen,
|
||||
zai: ZAIIcon,
|
||||
bedrock_converse: SvgAws,
|
||||
};
|
||||
|
||||
/**
|
||||
* Model-aware icon resolver that checks both provider name and model name
|
||||
* to pick the most specific icon (e.g. Claude icon for a Bedrock Claude model).
|
||||
*/
|
||||
export function getModelIcon(
|
||||
providerName: string,
|
||||
modelName?: string
|
||||
): IconFunctionComponent {
|
||||
const lowerProviderName = providerName.toLowerCase();
|
||||
|
||||
// For aggregator providers, prioritise showing the vendor icon based on model name
|
||||
if (AGGREGATOR_PROVIDERS.has(lowerProviderName) && modelName) {
|
||||
const lowerModelName = modelName.toLowerCase();
|
||||
for (const [key, icon] of Object.entries(MODEL_ICON_MAP)) {
|
||||
if (lowerModelName.includes(key)) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if provider name directly matches an icon
|
||||
if (lowerProviderName in MODEL_ICON_MAP) {
|
||||
const icon = MODEL_ICON_MAP[lowerProviderName];
|
||||
if (icon) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
|
||||
// For non-aggregator providers, check if model name contains any of the keys
|
||||
if (modelName) {
|
||||
const lowerModelName = modelName.toLowerCase();
|
||||
for (const [key, icon] of Object.entries(MODEL_ICON_MAP)) {
|
||||
if (lowerModelName.includes(key)) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to CPU icon if no matches
|
||||
return SvgCpu;
|
||||
}
|
||||
176
web/src/lib/llmConfig/providers.ts
Normal file
176
web/src/lib/llmConfig/providers.ts
Normal file
@@ -0,0 +1,176 @@
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { SvgCpu, SvgPlug, SvgServer } from "@opal/icons";
|
||||
import {
|
||||
SvgBifrost,
|
||||
SvgOpenai,
|
||||
SvgClaude,
|
||||
SvgOllama,
|
||||
SvgAws,
|
||||
SvgOpenrouter,
|
||||
SvgAzure,
|
||||
SvgGemini,
|
||||
SvgLitellm,
|
||||
SvgLmStudio,
|
||||
SvgMicrosoft,
|
||||
SvgMistral,
|
||||
SvgDeepseek,
|
||||
SvgQwen,
|
||||
SvgGoogle,
|
||||
} from "@opal/logos";
|
||||
import { ZAIIcon } from "@/components/icons/icons";
|
||||
import { LLMProviderName } from "@/interfaces/llm";
|
||||
|
||||
export const AGGREGATOR_PROVIDERS = new Set([
|
||||
LLMProviderName.BEDROCK,
|
||||
"bedrock_converse",
|
||||
LLMProviderName.OPENROUTER,
|
||||
LLMProviderName.OLLAMA_CHAT,
|
||||
LLMProviderName.LM_STUDIO,
|
||||
LLMProviderName.LITELLM_PROXY,
|
||||
LLMProviderName.BIFROST,
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
LLMProviderName.VERTEX_AI,
|
||||
]);
|
||||
|
||||
const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
|
||||
[LLMProviderName.OPENAI]: SvgOpenai,
|
||||
[LLMProviderName.ANTHROPIC]: SvgClaude,
|
||||
[LLMProviderName.VERTEX_AI]: SvgGemini,
|
||||
[LLMProviderName.BEDROCK]: SvgAws,
|
||||
[LLMProviderName.AZURE]: SvgAzure,
|
||||
[LLMProviderName.LITELLM]: SvgLitellm,
|
||||
[LLMProviderName.LITELLM_PROXY]: SvgLitellm,
|
||||
[LLMProviderName.OLLAMA_CHAT]: SvgOllama,
|
||||
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
|
||||
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
|
||||
[LLMProviderName.BIFROST]: SvgBifrost,
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: SvgPlug,
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: SvgServer,
|
||||
};
|
||||
|
||||
const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OPENAI]: "GPT",
|
||||
[LLMProviderName.ANTHROPIC]: "Claude",
|
||||
[LLMProviderName.VERTEX_AI]: "Gemini",
|
||||
[LLMProviderName.BEDROCK]: "Amazon Bedrock",
|
||||
[LLMProviderName.AZURE]: "Azure OpenAI",
|
||||
[LLMProviderName.LITELLM]: "LiteLLM",
|
||||
[LLMProviderName.LITELLM_PROXY]: "LiteLLM Proxy",
|
||||
[LLMProviderName.OLLAMA_CHAT]: "Ollama",
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI-Compatible",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Custom Models",
|
||||
};
|
||||
|
||||
const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OPENAI]: "OpenAI",
|
||||
[LLMProviderName.ANTHROPIC]: "Anthropic",
|
||||
[LLMProviderName.VERTEX_AI]: "Google Cloud Vertex AI",
|
||||
[LLMProviderName.BEDROCK]: "AWS",
|
||||
[LLMProviderName.AZURE]: "Microsoft Azure",
|
||||
[LLMProviderName.LITELLM]: "LiteLLM",
|
||||
[LLMProviderName.LITELLM_PROXY]: "LiteLLM Proxy",
|
||||
[LLMProviderName.OLLAMA_CHAT]: "Ollama",
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI-Compatible",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "models from other LiteLLM-compatible providers",
|
||||
};
|
||||
|
||||
export function getProviderProductName(providerName: string): string {
|
||||
return PROVIDER_PRODUCT_NAMES[providerName] ?? providerName;
|
||||
}
|
||||
|
||||
export function getProviderDisplayName(providerName: string): string {
|
||||
return PROVIDER_DISPLAY_NAMES[providerName] ?? providerName;
|
||||
}
|
||||
|
||||
export function getProviderIcon(providerName: string): IconFunctionComponent {
|
||||
return PROVIDER_ICONS[providerName] ?? SvgCpu;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model-aware icon resolver (legacy icon set)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const MODEL_ICON_MAP: Record<string, IconFunctionComponent> = {
|
||||
[LLMProviderName.OPENAI]: SvgOpenai,
|
||||
[LLMProviderName.ANTHROPIC]: SvgClaude,
|
||||
[LLMProviderName.OLLAMA_CHAT]: SvgOllama,
|
||||
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
|
||||
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
|
||||
[LLMProviderName.VERTEX_AI]: SvgGemini,
|
||||
[LLMProviderName.BEDROCK]: SvgAws,
|
||||
[LLMProviderName.LITELLM_PROXY]: SvgLitellm,
|
||||
[LLMProviderName.BIFROST]: SvgBifrost,
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: SvgPlug,
|
||||
|
||||
amazon: SvgAws,
|
||||
phi: SvgMicrosoft,
|
||||
mistral: SvgMistral,
|
||||
ministral: SvgMistral,
|
||||
llama: SvgCpu,
|
||||
ollama: SvgOllama,
|
||||
gemini: SvgGemini,
|
||||
deepseek: SvgDeepseek,
|
||||
claude: SvgClaude,
|
||||
azure: SvgAzure,
|
||||
microsoft: SvgMicrosoft,
|
||||
meta: SvgCpu,
|
||||
google: SvgGoogle,
|
||||
qwen: SvgQwen,
|
||||
qwq: SvgQwen,
|
||||
zai: ZAIIcon,
|
||||
bedrock_converse: SvgAws,
|
||||
};
|
||||
|
||||
/**
|
||||
* Model-aware icon resolver that checks both provider name and model name
|
||||
* to pick the most specific icon (e.g. Claude icon for a Bedrock Claude model).
|
||||
*/
|
||||
export const getModelIcon = (
|
||||
providerName: string,
|
||||
modelName?: string
|
||||
): IconFunctionComponent => {
|
||||
const lowerProviderName = providerName.toLowerCase();
|
||||
|
||||
// For aggregator providers, prioritise showing the vendor icon based on model name
|
||||
if (AGGREGATOR_PROVIDERS.has(lowerProviderName) && modelName) {
|
||||
const lowerModelName = modelName.toLowerCase();
|
||||
for (const [key, icon] of Object.entries(MODEL_ICON_MAP)) {
|
||||
if (lowerModelName.includes(key)) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if provider name directly matches an icon
|
||||
if (lowerProviderName in MODEL_ICON_MAP) {
|
||||
const icon = MODEL_ICON_MAP[lowerProviderName];
|
||||
if (icon) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
|
||||
// For non-aggregator providers, check if model name contains any of the keys
|
||||
if (modelName) {
|
||||
const lowerModelName = modelName.toLowerCase();
|
||||
for (const [key, icon] of Object.entries(MODEL_ICON_MAP)) {
|
||||
if (lowerModelName.includes(key)) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to CPU icon if no matches
|
||||
return SvgCpu;
|
||||
};
|
||||
@@ -44,7 +44,7 @@ export function getFinalLLM(
|
||||
return [provider, model];
|
||||
}
|
||||
|
||||
export function getProviderOverrideForPersona(
|
||||
export function getLLMProviderOverrideForPersona(
|
||||
liveAgent: MinimalPersonaSnapshot,
|
||||
llmProviders: LLMProviderDescriptor[]
|
||||
): LlmDescriptor | null {
|
||||
@@ -144,7 +144,7 @@ export function getDisplayName(
|
||||
agent: MinimalPersonaSnapshot,
|
||||
llmProviders: LLMProviderDescriptor[]
|
||||
): string | undefined {
|
||||
const llmDescriptor = getProviderOverrideForPersona(
|
||||
const llmDescriptor = getLLMProviderOverrideForPersona(
|
||||
agent,
|
||||
llmProviders ?? []
|
||||
);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import copy from "copy-to-clipboard";
|
||||
import { Button, ButtonProps } from "@opal/components";
|
||||
import { SvgAlertTriangle, SvgCheck, SvgCopy } from "@opal/icons";
|
||||
|
||||
@@ -41,19 +40,26 @@ export default function CopyIconButton({
|
||||
}
|
||||
|
||||
try {
|
||||
if (navigator.clipboard && getHtmlContent) {
|
||||
// Check if Clipboard API is available
|
||||
if (!navigator.clipboard) {
|
||||
throw new Error("Clipboard API not available");
|
||||
}
|
||||
|
||||
// If HTML content getter is provided, copy both HTML and plain text
|
||||
if (getHtmlContent) {
|
||||
const htmlContent = getHtmlContent();
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob([htmlContent], { type: "text/html" }),
|
||||
"text/plain": new Blob([text], { type: "text/plain" }),
|
||||
});
|
||||
await navigator.clipboard.write([clipboardItem]);
|
||||
} else if (navigator.clipboard) {
|
||||
}
|
||||
// Default: plain text only
|
||||
else {
|
||||
await navigator.clipboard.writeText(text);
|
||||
} else if (!copy(text)) {
|
||||
throw new Error("copy-to-clipboard returned false");
|
||||
}
|
||||
|
||||
// Show "copied" state
|
||||
setCopyState("copied");
|
||||
} catch (err) {
|
||||
console.error("Failed to copy:", err);
|
||||
|
||||
@@ -4,7 +4,7 @@ import { useState, useEffect, useCallback, useMemo, useRef } from "react";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import { LlmDescriptor, LlmManager } from "@/lib/hooks";
|
||||
import { structureValue } from "@/lib/llmConfig/utils";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import { AGGREGATOR_PROVIDERS } from "@/lib/llmConfig/svc";
|
||||
|
||||
import { Slider } from "@/components/ui/slider";
|
||||
|
||||
@@ -3,10 +3,9 @@
|
||||
import { useState, useMemo, useRef } from "react";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import { Button, SelectButton, OpenButton } from "@opal/components";
|
||||
import { SvgPlusCircle, SvgX } from "@opal/icons";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { LLMOption } from "@/refresh-components/popovers/interfaces";
|
||||
import ModelListContent from "@/refresh-components/popovers/ModelListContent";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
@@ -45,12 +44,8 @@ export default function ModelSelector({
|
||||
// Virtual anchor ref — points to the clicked pill so the popover positions above it
|
||||
const anchorRef = useRef<HTMLElement | null>(null);
|
||||
|
||||
const settings = useSettingsContext();
|
||||
const multiModelAllowed =
|
||||
settings?.settings?.multi_model_chat_enabled ?? true;
|
||||
|
||||
const isMultiModel = selectedModels.length > 1;
|
||||
const atMax = selectedModels.length >= MAX_MODELS || !multiModelAllowed;
|
||||
const atMax = selectedModels.length >= MAX_MODELS;
|
||||
|
||||
const selectedKeys = useMemo(
|
||||
() => new Set(selectedModels.map((m) => modelKey(m.provider, m.modelName))),
|
||||
@@ -109,7 +104,6 @@ export default function ModelSelector({
|
||||
onRemove(existingIndex);
|
||||
} else if (!atMax) {
|
||||
onAdd(model);
|
||||
setOpen(false);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -164,12 +158,9 @@ export default function ModelSelector({
|
||||
);
|
||||
|
||||
if (!isMultiModel) {
|
||||
// Stable key — keying on model would unmount the pill
|
||||
// on change and leave Radix's anchorRef detached,
|
||||
// flashing the closing popover at (0,0).
|
||||
return (
|
||||
<OpenButton
|
||||
key="single-model-pill"
|
||||
key={modelKey(model.provider, model.modelName)}
|
||||
icon={ProviderIcon}
|
||||
onClick={(e: React.MouseEvent) =>
|
||||
handlePillClick(index, e.currentTarget as HTMLElement)
|
||||
@@ -223,17 +214,15 @@ export default function ModelSelector({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{!(atMax && replacingIndex === null) && (
|
||||
<Popover.Content side="top" align="end" width="lg">
|
||||
<ModelListContent
|
||||
llmProviders={llmManager.llmProviders}
|
||||
isLoading={llmManager.isLoadingProviders}
|
||||
onSelect={handleSelect}
|
||||
isSelected={isSelected}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Popover.Content>
|
||||
)}
|
||||
<Popover.Content side="top" align="end" width="lg">
|
||||
<ModelListContent
|
||||
llmProviders={llmManager.llmProviders}
|
||||
isLoading={llmManager.isLoadingProviders}
|
||||
onSelect={handleSelect}
|
||||
isSelected={isSelected}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -400,22 +400,19 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const multiModel = useMultiModelChat(llmManager);
|
||||
|
||||
// Auto-fold sidebar when a multi-model message is submitted.
|
||||
// Stays collapsed until the user exits multi-model mode (removes models).
|
||||
// Auto-fold sidebar when multi-model is active (panels need full width)
|
||||
const { folded: sidebarFolded, setFolded: setSidebarFolded } =
|
||||
useSidebarState();
|
||||
const preMultiModelFoldedRef = useRef<boolean | null>(null);
|
||||
|
||||
const foldSidebarForMultiModel = useCallback(() => {
|
||||
if (preMultiModelFoldedRef.current === null) {
|
||||
preMultiModelFoldedRef.current = sidebarFolded;
|
||||
setSidebarFolded(true);
|
||||
}
|
||||
}, [sidebarFolded, setSidebarFolded]);
|
||||
|
||||
// Restore sidebar when user exits multi-model mode
|
||||
useEffect(() => {
|
||||
if (
|
||||
multiModel.isMultiModelActive &&
|
||||
preMultiModelFoldedRef.current === null
|
||||
) {
|
||||
preMultiModelFoldedRef.current = sidebarFolded;
|
||||
setSidebarFolded(true);
|
||||
} else if (
|
||||
!multiModel.isMultiModelActive &&
|
||||
preMultiModelFoldedRef.current !== null
|
||||
) {
|
||||
@@ -425,27 +422,16 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [multiModel.isMultiModelActive]);
|
||||
|
||||
// Sync single-model selection to llmManager so the submission path uses
|
||||
// the correct provider/version. Guard against echoing derived state back
|
||||
// — only call updateCurrentLlm when the selection actually differs from
|
||||
// currentLlm, otherwise the initial [] → [currentLlmModel] sync would
|
||||
// pin `userHasManuallyOverriddenLLM=true` with whatever was resolved
|
||||
// first (often the default model before the session's alt_model loads).
|
||||
// Sync single-model selection to llmManager so the submission path
|
||||
// uses the correct provider/version (replaces the old LLMPopover sync).
|
||||
useEffect(() => {
|
||||
if (multiModel.selectedModels.length === 1) {
|
||||
const model = multiModel.selectedModels[0]!;
|
||||
const current = llmManager.currentLlm;
|
||||
if (
|
||||
model.provider !== current.provider ||
|
||||
model.modelName !== current.modelName ||
|
||||
model.name !== current.name
|
||||
) {
|
||||
llmManager.updateCurrentLlm({
|
||||
name: model.name,
|
||||
provider: model.provider,
|
||||
modelName: model.modelName,
|
||||
});
|
||||
}
|
||||
llmManager.updateCurrentLlm({
|
||||
name: model.name,
|
||||
provider: model.provider,
|
||||
modelName: model.modelName,
|
||||
});
|
||||
}
|
||||
}, [multiModel.selectedModels]);
|
||||
|
||||
@@ -522,8 +508,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
onSubmit({
|
||||
message: lastUserMsg.message,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch:
|
||||
deepResearchEnabledForCurrentWorkflow && !multiModel.isMultiModelActive,
|
||||
deepResearch: deepResearchEnabledForCurrentWorkflow,
|
||||
messageIdToResend: lastUserMsg.messageId,
|
||||
});
|
||||
}, [
|
||||
@@ -531,7 +516,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
onSubmit,
|
||||
currentMessageFiles,
|
||||
deepResearchEnabledForCurrentWorkflow,
|
||||
multiModel.isMultiModelActive,
|
||||
]);
|
||||
|
||||
const toggleDocumentSidebar = useCallback(() => {
|
||||
@@ -548,16 +532,11 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const onChat = useCallback(
|
||||
(message: string) => {
|
||||
if (multiModel.isMultiModelActive) {
|
||||
foldSidebarForMultiModel();
|
||||
}
|
||||
resetInputBar();
|
||||
onSubmit({
|
||||
message,
|
||||
currentMessageFiles,
|
||||
deepResearch:
|
||||
deepResearchEnabledForCurrentWorkflow &&
|
||||
!multiModel.isMultiModelActive,
|
||||
deepResearch: deepResearchEnabledForCurrentWorkflow,
|
||||
selectedModels: multiModel.isMultiModelActive
|
||||
? multiModel.selectedModels
|
||||
: undefined,
|
||||
@@ -573,7 +552,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
deepResearchEnabledForCurrentWorkflow,
|
||||
multiModel.isMultiModelActive,
|
||||
multiModel.selectedModels,
|
||||
foldSidebarForMultiModel,
|
||||
showOnboarding,
|
||||
onboardingDismissed,
|
||||
finishOnboarding,
|
||||
@@ -611,9 +589,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
onSubmit({
|
||||
message,
|
||||
currentMessageFiles,
|
||||
deepResearch:
|
||||
deepResearchEnabledForCurrentWorkflow &&
|
||||
!multiModel.isMultiModelActive,
|
||||
deepResearch: deepResearchEnabledForCurrentWorkflow,
|
||||
selectedModels: multiModel.isMultiModelActive
|
||||
? multiModel.selectedModels
|
||||
: undefined,
|
||||
@@ -888,20 +864,13 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
agent={liveAgent}
|
||||
isDefaultAgent={isDefaultAgent}
|
||||
/>
|
||||
{!isSearch &&
|
||||
!(
|
||||
state.phase === "idle" && state.appMode === "search"
|
||||
) &&
|
||||
liveAgent &&
|
||||
!llmManager.isLoadingProviders && (
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
)}
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
</Section>
|
||||
<Spacer rem={1.5} />
|
||||
</Fade>
|
||||
@@ -967,26 +936,23 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
isSearch ? "h-[14px]" : "h-0"
|
||||
)}
|
||||
/>
|
||||
{appFocus.isChat() &&
|
||||
liveAgent &&
|
||||
!llmManager.isLoadingProviders && (
|
||||
<div className="pb-1">
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{appFocus.isChat() && (
|
||||
<div className="pb-1">
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<AppInputBar
|
||||
ref={chatInputBarRef}
|
||||
deepResearchEnabled={
|
||||
deepResearchEnabledForCurrentWorkflow
|
||||
}
|
||||
toggleDeepResearch={toggleDeepResearch}
|
||||
isMultiModelActive={multiModel.isMultiModelActive}
|
||||
filterManager={filterManager}
|
||||
llmManager={llmManager}
|
||||
initialMessage={
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"use client";
|
||||
|
||||
import { markdown } from "@opal/utils";
|
||||
import React, { useCallback, useEffect, useRef, useState } from "react";
|
||||
import React, { useCallback, useRef, useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Formik, Form } from "formik";
|
||||
import { Formik, Form, useFormikContext } from "formik";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
@@ -14,9 +14,10 @@ import Card from "@/refresh-components/cards/Card";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import SimpleCollapsible from "@/refresh-components/SimpleCollapsible";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import SwitchField from "@/refresh-components/form/SwitchField";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import InputTextAreaField from "@/refresh-components/form/InputTextAreaField";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputTextArea from "@/refresh-components/inputs/InputTextArea";
|
||||
import InputSelectField from "@/refresh-components/form/InputSelectField";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import {
|
||||
SvgAddLines,
|
||||
@@ -56,6 +57,7 @@ import * as ActionsLayouts from "@/layouts/actions-layouts";
|
||||
import { getActionIcon } from "@/lib/tools/mcpUtils";
|
||||
import { Disabled, Hoverable } from "@opal/core";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import useFilter from "@/hooks/useFilter";
|
||||
import { MCPServer } from "@/lib/tools/interfaces";
|
||||
import type { IconProps } from "@opal/types";
|
||||
@@ -68,6 +70,26 @@ interface DefaultAgentConfiguration {
|
||||
default_system_prompt: string;
|
||||
}
|
||||
|
||||
interface ChatPreferencesFormValues {
|
||||
// Features
|
||||
search_ui_enabled: boolean;
|
||||
deep_research_enabled: boolean;
|
||||
auto_scroll: boolean;
|
||||
|
||||
// Team context
|
||||
company_name: string;
|
||||
company_description: string;
|
||||
|
||||
// Advanced
|
||||
maximum_chat_retention_days: string;
|
||||
anonymous_user_enabled: boolean;
|
||||
disable_default_assistant: boolean;
|
||||
|
||||
// File limits
|
||||
user_file_max_upload_size_mb: string;
|
||||
file_token_count_threshold_k: string;
|
||||
}
|
||||
|
||||
interface MCPServerCardTool {
|
||||
id: number;
|
||||
icon: React.FunctionComponent<IconProps>;
|
||||
@@ -176,7 +198,6 @@ type FileLimitFieldName =
|
||||
|
||||
interface NumericLimitFieldProps {
|
||||
name: FileLimitFieldName;
|
||||
initialValue: string;
|
||||
defaultValue: string;
|
||||
saveSettings: (updates: Partial<Settings>) => Promise<void>;
|
||||
maxValue?: number;
|
||||
@@ -185,15 +206,16 @@ interface NumericLimitFieldProps {
|
||||
|
||||
function NumericLimitField({
|
||||
name,
|
||||
initialValue: initialValueProp,
|
||||
defaultValue,
|
||||
saveSettings,
|
||||
maxValue,
|
||||
allowZero = false,
|
||||
}: NumericLimitFieldProps) {
|
||||
const [value, setValue] = useState(initialValueProp);
|
||||
const savedValue = useRef(initialValueProp);
|
||||
const { values, setFieldValue } =
|
||||
useFormikContext<ChatPreferencesFormValues>();
|
||||
const initialValue = useRef(values[name]);
|
||||
const restoringRef = useRef(false);
|
||||
const value = values[name];
|
||||
|
||||
const parsed = parseInt(value, 10);
|
||||
const isOverMax =
|
||||
@@ -201,8 +223,8 @@ function NumericLimitField({
|
||||
|
||||
const handleRestore = () => {
|
||||
restoringRef.current = true;
|
||||
savedValue.current = defaultValue;
|
||||
setValue(defaultValue);
|
||||
initialValue.current = defaultValue;
|
||||
void setFieldValue(name, defaultValue);
|
||||
void saveSettings({ [name]: parseInt(defaultValue, 10) });
|
||||
};
|
||||
|
||||
@@ -220,11 +242,11 @@ function NumericLimitField({
|
||||
if (!isValid) {
|
||||
if (allowZero) {
|
||||
// Empty/invalid means "no limit" — persist 0 and clear the field.
|
||||
setValue("");
|
||||
void setFieldValue(name, "");
|
||||
void saveSettings({ [name]: 0 });
|
||||
savedValue.current = "";
|
||||
initialValue.current = "";
|
||||
} else {
|
||||
setValue(savedValue.current);
|
||||
void setFieldValue(name, initialValue.current);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -237,10 +259,10 @@ function NumericLimitField({
|
||||
// For allowZero fields, 0 means "no limit" — clear the display
|
||||
// so the "No limit" placeholder is visible, but still persist 0.
|
||||
if (allowZero && parsed === 0) {
|
||||
setValue("");
|
||||
if (savedValue.current !== "") {
|
||||
void setFieldValue(name, "");
|
||||
if (initialValue.current !== "") {
|
||||
void saveSettings({ [name]: 0 });
|
||||
savedValue.current = "";
|
||||
initialValue.current = "";
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -249,24 +271,23 @@ function NumericLimitField({
|
||||
|
||||
// Update the display to the canonical form (e.g. strip leading zeros).
|
||||
if (value !== normalizedDisplay) {
|
||||
setValue(normalizedDisplay);
|
||||
void setFieldValue(name, normalizedDisplay);
|
||||
}
|
||||
|
||||
// Persist only when the value actually changed.
|
||||
if (normalizedDisplay !== savedValue.current) {
|
||||
if (normalizedDisplay !== initialValue.current) {
|
||||
void saveSettings({ [name]: parsed });
|
||||
savedValue.current = normalizedDisplay;
|
||||
initialValue.current = normalizedDisplay;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Hoverable.Root group="numericLimit" widthVariant="full">
|
||||
<InputTypeIn
|
||||
<InputTypeInField
|
||||
name={name}
|
||||
inputMode="numeric"
|
||||
showClearButton={false}
|
||||
pattern="[0-9]*"
|
||||
value={value}
|
||||
onChange={(e) => setValue(e.target.value)}
|
||||
placeholder={allowZero ? "No limit" : `Default: ${defaultValue}`}
|
||||
variant={isOverMax ? "error" : undefined}
|
||||
rightSection={
|
||||
@@ -290,18 +311,14 @@ function NumericLimitField({
|
||||
|
||||
interface FileSizeLimitFieldsProps {
|
||||
saveSettings: (updates: Partial<Settings>) => Promise<void>;
|
||||
initialUploadSizeMb: string;
|
||||
defaultUploadSizeMb: string;
|
||||
initialTokenThresholdK: string;
|
||||
defaultTokenThresholdK: string;
|
||||
maxAllowedUploadSizeMb?: number;
|
||||
}
|
||||
|
||||
function FileSizeLimitFields({
|
||||
saveSettings,
|
||||
initialUploadSizeMb,
|
||||
defaultUploadSizeMb,
|
||||
initialTokenThresholdK,
|
||||
defaultTokenThresholdK,
|
||||
maxAllowedUploadSizeMb,
|
||||
}: FileSizeLimitFieldsProps) {
|
||||
@@ -319,7 +336,6 @@ function FileSizeLimitFields({
|
||||
>
|
||||
<NumericLimitField
|
||||
name="user_file_max_upload_size_mb"
|
||||
initialValue={initialUploadSizeMb}
|
||||
defaultValue={defaultUploadSizeMb}
|
||||
saveSettings={saveSettings}
|
||||
maxValue={maxAllowedUploadSizeMb}
|
||||
@@ -333,7 +349,6 @@ function FileSizeLimitFields({
|
||||
>
|
||||
<NumericLimitField
|
||||
name="file_token_count_threshold_k"
|
||||
initialValue={initialTokenThresholdK}
|
||||
defaultValue={defaultTokenThresholdK}
|
||||
saveSettings={saveSettings}
|
||||
allowZero
|
||||
@@ -344,39 +359,18 @@ function FileSizeLimitFields({
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Inner form component that uses useFormikContext to access values
|
||||
* and create save handlers for settings fields.
|
||||
*/
|
||||
function ChatPreferencesForm() {
|
||||
const router = useRouter();
|
||||
const settings = useSettingsContext();
|
||||
const s = settings.settings;
|
||||
const { values } = useFormikContext<ChatPreferencesFormValues>();
|
||||
|
||||
// Local state for text fields (save-on-blur)
|
||||
const [companyName, setCompanyName] = useState(s.company_name ?? "");
|
||||
const [companyDescription, setCompanyDescription] = useState(
|
||||
s.company_description ?? ""
|
||||
);
|
||||
const savedCompanyName = useRef(companyName);
|
||||
const savedCompanyDescription = useRef(companyDescription);
|
||||
|
||||
// Re-sync local state when settings change externally (e.g. another admin),
|
||||
// but only when there's no in-progress edit (local matches last-saved value).
|
||||
useEffect(() => {
|
||||
const incoming = s.company_name ?? "";
|
||||
if (companyName === savedCompanyName.current && incoming !== companyName) {
|
||||
setCompanyName(incoming);
|
||||
savedCompanyName.current = incoming;
|
||||
}
|
||||
}, [s.company_name]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
useEffect(() => {
|
||||
const incoming = s.company_description ?? "";
|
||||
if (
|
||||
companyDescription === savedCompanyDescription.current &&
|
||||
incoming !== companyDescription
|
||||
) {
|
||||
setCompanyDescription(incoming);
|
||||
savedCompanyDescription.current = incoming;
|
||||
}
|
||||
}, [s.company_description]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
// Track initial text values to avoid unnecessary saves on blur
|
||||
const initialCompanyName = useRef(values.company_name);
|
||||
const initialCompanyDescription = useRef(values.company_description);
|
||||
|
||||
// Tools availability
|
||||
const { tools: availableTools } = useAvailableTools();
|
||||
@@ -532,18 +526,16 @@ function ChatPreferencesForm() {
|
||||
<InputLayouts.Vertical
|
||||
title="Team Name"
|
||||
subDescription="This is added to all chat sessions as additional context to provide a richer/customized experience."
|
||||
nonInteractive
|
||||
>
|
||||
<InputTypeIn
|
||||
<InputTypeInField
|
||||
name="company_name"
|
||||
placeholder="Enter team name"
|
||||
value={companyName}
|
||||
onChange={(e) => setCompanyName(e.target.value)}
|
||||
onBlur={() => {
|
||||
if (companyName !== savedCompanyName.current) {
|
||||
if (values.company_name !== initialCompanyName.current) {
|
||||
void saveSettings({
|
||||
company_name: companyName || null,
|
||||
company_name: values.company_name || null,
|
||||
});
|
||||
savedCompanyName.current = companyName;
|
||||
initialCompanyName.current = values.company_name;
|
||||
}
|
||||
}}
|
||||
/>
|
||||
@@ -552,21 +544,23 @@ function ChatPreferencesForm() {
|
||||
<InputLayouts.Vertical
|
||||
title="Team Context"
|
||||
subDescription="Users can also provide additional individual context in their personal settings."
|
||||
nonInteractive
|
||||
>
|
||||
<InputTextArea
|
||||
<InputTextAreaField
|
||||
name="company_description"
|
||||
placeholder="Describe your team and how Onyx should behave."
|
||||
rows={4}
|
||||
maxRows={10}
|
||||
autoResize
|
||||
value={companyDescription}
|
||||
onChange={(e) => setCompanyDescription(e.target.value)}
|
||||
onBlur={() => {
|
||||
if (companyDescription !== savedCompanyDescription.current) {
|
||||
if (
|
||||
values.company_description !==
|
||||
initialCompanyDescription.current
|
||||
) {
|
||||
void saveSettings({
|
||||
company_description: companyDescription || null,
|
||||
company_description: values.company_description || null,
|
||||
});
|
||||
savedCompanyDescription.current = companyDescription;
|
||||
initialCompanyDescription.current =
|
||||
values.company_description;
|
||||
}
|
||||
}}
|
||||
/>
|
||||
@@ -610,10 +604,9 @@ function ChatPreferencesForm() {
|
||||
title="Search Mode"
|
||||
description="UI mode for quick document search across your organization."
|
||||
disabled={uniqueSources.length === 0}
|
||||
nonInteractive
|
||||
>
|
||||
<Switch
|
||||
checked={s.search_ui_enabled ?? false}
|
||||
<SwitchField
|
||||
name="search_ui_enabled"
|
||||
onCheckedChange={(checked) => {
|
||||
void saveSettings({ search_ui_enabled: checked });
|
||||
}}
|
||||
@@ -623,26 +616,12 @@ function ChatPreferencesForm() {
|
||||
</div>
|
||||
</Disabled>
|
||||
</SimpleTooltip>
|
||||
<InputLayouts.Horizontal
|
||||
title="Multi-Model Generation"
|
||||
tag={{ title: "beta", color: "blue" }}
|
||||
description="Allow multiple models to generate responses in parallel in chat."
|
||||
nonInteractive
|
||||
>
|
||||
<Switch
|
||||
checked={s.multi_model_chat_enabled ?? true}
|
||||
onCheckedChange={(checked) => {
|
||||
void saveSettings({ multi_model_chat_enabled: checked });
|
||||
}}
|
||||
/>
|
||||
</InputLayouts.Horizontal>
|
||||
<InputLayouts.Horizontal
|
||||
title="Deep Research"
|
||||
description="Agentic research system that works across the web and connected sources. Uses significantly more tokens per query."
|
||||
nonInteractive
|
||||
>
|
||||
<Switch
|
||||
checked={s.deep_research_enabled ?? true}
|
||||
<SwitchField
|
||||
name="deep_research_enabled"
|
||||
onCheckedChange={(checked) => {
|
||||
void saveSettings({ deep_research_enabled: checked });
|
||||
}}
|
||||
@@ -651,10 +630,9 @@ function ChatPreferencesForm() {
|
||||
<InputLayouts.Horizontal
|
||||
title="Chat Auto-Scroll"
|
||||
description="Automatically scroll to new content as chat generates response. Users can override this in their personal settings."
|
||||
nonInteractive
|
||||
>
|
||||
<Switch
|
||||
checked={s.auto_scroll ?? false}
|
||||
<SwitchField
|
||||
name="auto_scroll"
|
||||
onCheckedChange={(checked) => {
|
||||
void saveSettings({ auto_scroll: checked });
|
||||
}}
|
||||
@@ -665,7 +643,7 @@ function ChatPreferencesForm() {
|
||||
|
||||
<Separator noPadding />
|
||||
|
||||
<Disabled disabled={s.disable_default_assistant ?? false}>
|
||||
<Disabled disabled={values.disable_default_assistant}>
|
||||
<div>
|
||||
<Section gap={1.5}>
|
||||
{/* Connectors */}
|
||||
@@ -895,12 +873,9 @@ function ChatPreferencesForm() {
|
||||
<InputLayouts.Horizontal
|
||||
title="Keep Chat History"
|
||||
description="Specify how long Onyx should retain chats in your organization."
|
||||
nonInteractive
|
||||
>
|
||||
<InputSelect
|
||||
value={
|
||||
s.maximum_chat_retention_days?.toString() ?? "forever"
|
||||
}
|
||||
<InputSelectField
|
||||
name="maximum_chat_retention_days"
|
||||
onValueChange={(value) => {
|
||||
void saveSettings({
|
||||
maximum_chat_retention_days:
|
||||
@@ -920,7 +895,7 @@ function ChatPreferencesForm() {
|
||||
365 days
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</InputSelectField>
|
||||
</InputLayouts.Horizontal>
|
||||
</Card>
|
||||
|
||||
@@ -931,29 +906,17 @@ function ChatPreferencesForm() {
|
||||
>
|
||||
<FileSizeLimitFields
|
||||
saveSettings={saveSettings}
|
||||
initialUploadSizeMb={
|
||||
(s.user_file_max_upload_size_mb ?? 0) <= 0
|
||||
? s.default_user_file_max_upload_size_mb?.toString() ??
|
||||
"100"
|
||||
: s.user_file_max_upload_size_mb!.toString()
|
||||
}
|
||||
defaultUploadSizeMb={
|
||||
s.default_user_file_max_upload_size_mb?.toString() ??
|
||||
settings?.settings.default_user_file_max_upload_size_mb?.toString() ??
|
||||
"100"
|
||||
}
|
||||
initialTokenThresholdK={
|
||||
s.file_token_count_threshold_k == null
|
||||
? s.default_file_token_count_threshold_k?.toString() ??
|
||||
"200"
|
||||
: s.file_token_count_threshold_k === 0
|
||||
? ""
|
||||
: s.file_token_count_threshold_k.toString()
|
||||
}
|
||||
defaultTokenThresholdK={
|
||||
s.default_file_token_count_threshold_k?.toString() ??
|
||||
settings?.settings.default_file_token_count_threshold_k?.toString() ??
|
||||
"200"
|
||||
}
|
||||
maxAllowedUploadSizeMb={s.max_allowed_upload_size_mb}
|
||||
maxAllowedUploadSizeMb={
|
||||
settings?.settings.max_allowed_upload_size_mb
|
||||
}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</Card>
|
||||
@@ -962,10 +925,9 @@ function ChatPreferencesForm() {
|
||||
<InputLayouts.Horizontal
|
||||
title="Allow Anonymous Users"
|
||||
description="Allow anyone to start chats without logging in. They do not see any other chats and cannot create agents or update settings."
|
||||
nonInteractive
|
||||
>
|
||||
<Switch
|
||||
checked={s.anonymous_user_enabled ?? false}
|
||||
<SwitchField
|
||||
name="anonymous_user_enabled"
|
||||
onCheckedChange={(checked) => {
|
||||
void saveSettings({ anonymous_user_enabled: checked });
|
||||
}}
|
||||
@@ -975,11 +937,9 @@ function ChatPreferencesForm() {
|
||||
<InputLayouts.Horizontal
|
||||
title="Always Start with an Agent"
|
||||
description="This removes the default chat. Users will always start in an agent, and new chats will be created in their last active agent. Set featured agents to help new users get started."
|
||||
nonInteractive
|
||||
>
|
||||
<Switch
|
||||
id="disable_default_assistant"
|
||||
checked={s.disable_default_assistant ?? false}
|
||||
<SwitchField
|
||||
name="disable_default_assistant"
|
||||
onCheckedChange={(checked) => {
|
||||
void saveSettings({
|
||||
disable_default_assistant: checked,
|
||||
@@ -1082,5 +1042,50 @@ function ChatPreferencesForm() {
|
||||
}
|
||||
|
||||
export default function ChatPreferencesPage() {
|
||||
return <ChatPreferencesForm />;
|
||||
const settings = useSettingsContext();
|
||||
|
||||
const initialValues: ChatPreferencesFormValues = {
|
||||
// Features
|
||||
search_ui_enabled: settings.settings.search_ui_enabled ?? false,
|
||||
deep_research_enabled: settings.settings.deep_research_enabled ?? true,
|
||||
auto_scroll: settings.settings.auto_scroll ?? false,
|
||||
|
||||
// Team context
|
||||
company_name: settings.settings.company_name ?? "",
|
||||
company_description: settings.settings.company_description ?? "",
|
||||
|
||||
// Advanced
|
||||
maximum_chat_retention_days:
|
||||
settings.settings.maximum_chat_retention_days?.toString() ?? "forever",
|
||||
anonymous_user_enabled: settings.settings.anonymous_user_enabled ?? false,
|
||||
disable_default_assistant:
|
||||
settings.settings.disable_default_assistant ?? false,
|
||||
|
||||
// File limits — for upload size: 0/null means "use default";
|
||||
// for token threshold: null means "use default", 0 means "no limit".
|
||||
user_file_max_upload_size_mb:
|
||||
(settings.settings.user_file_max_upload_size_mb ?? 0) <= 0
|
||||
? settings.settings.default_user_file_max_upload_size_mb?.toString() ??
|
||||
"100"
|
||||
: settings.settings.user_file_max_upload_size_mb!.toString(),
|
||||
file_token_count_threshold_k:
|
||||
settings.settings.file_token_count_threshold_k == null
|
||||
? settings.settings.default_file_token_count_threshold_k?.toString() ??
|
||||
"200"
|
||||
: settings.settings.file_token_count_threshold_k === 0
|
||||
? ""
|
||||
: settings.settings.file_token_count_threshold_k.toString(),
|
||||
};
|
||||
|
||||
return (
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
onSubmit={() => {}}
|
||||
enableReinitialize
|
||||
>
|
||||
<Form className="h-full w-full">
|
||||
<ChatPreferencesForm />
|
||||
</Form>
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -15,7 +15,11 @@ import { SvgArrowExchange, SvgSettings, SvgTrash } from "@opal/icons";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import * as GeneralLayouts from "@/layouts/general-layouts";
|
||||
import { getProvider } from "@/lib/llmConfig";
|
||||
import {
|
||||
getProviderDisplayName,
|
||||
getProviderIcon,
|
||||
getProviderProductName,
|
||||
} from "@/lib/llmConfig/providers";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { deleteLlmProvider, setDefaultLlmModel } from "@/lib/llmConfig/svc";
|
||||
import { Horizontal as HorizontalInput } from "@/layouts/input-layouts";
|
||||
@@ -29,6 +33,19 @@ import {
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import { getModalForExistingProvider } from "@/sections/modals/llmConfig/getModal";
|
||||
import OpenAIModal from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
import AnthropicModal from "@/sections/modals/llmConfig/AnthropicModal";
|
||||
import OllamaModal from "@/sections/modals/llmConfig/OllamaModal";
|
||||
import AzureModal from "@/sections/modals/llmConfig/AzureModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { markdown } from "@opal/utils";
|
||||
|
||||
@@ -55,6 +72,51 @@ const PROVIDER_DISPLAY_ORDER: string[] = [
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
];
|
||||
|
||||
const PROVIDER_MODAL_MAP: Record<
|
||||
string,
|
||||
(
|
||||
shouldMarkAsDefault: boolean,
|
||||
onOpenChange: (open: boolean) => void
|
||||
) => React.ReactNode
|
||||
> = {
|
||||
openai: (d, onOpenChange) => (
|
||||
<OpenAIModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
anthropic: (d, onOpenChange) => (
|
||||
<AnthropicModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
ollama_chat: (d, onOpenChange) => (
|
||||
<OllamaModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
azure: (d, onOpenChange) => (
|
||||
<AzureModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
bedrock: (d, onOpenChange) => (
|
||||
<BedrockModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
vertex_ai: (d, onOpenChange) => (
|
||||
<VertexAIModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
openrouter: (d, onOpenChange) => (
|
||||
<OpenRouterModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
lm_studio: (d, onOpenChange) => (
|
||||
<LMStudioModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
litellm_proxy: (d, onOpenChange) => (
|
||||
<LiteLLMProxyModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
bifrost: (d, onOpenChange) => (
|
||||
<BifrostModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
openai_compatible: (d, onOpenChange) => (
|
||||
<OpenAICompatibleModal
|
||||
shouldMarkAsDefault={d}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// ExistingProviderCard — card for configured (existing) providers
|
||||
// ============================================================================
|
||||
@@ -63,12 +125,14 @@ interface ExistingProviderCardProps {
|
||||
provider: LLMProviderView;
|
||||
isDefault: boolean;
|
||||
isLastProvider: boolean;
|
||||
defaultModelName?: string;
|
||||
}
|
||||
|
||||
function ExistingProviderCard({
|
||||
provider,
|
||||
isDefault,
|
||||
isLastProvider,
|
||||
defaultModelName,
|
||||
}: ExistingProviderCardProps) {
|
||||
const { mutate } = useSWRConfig();
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
@@ -86,14 +150,8 @@ function ExistingProviderCard({
|
||||
}
|
||||
};
|
||||
|
||||
const { icon, companyName, Modal } = getProvider(provider.provider, provider);
|
||||
|
||||
return (
|
||||
<>
|
||||
{isOpen && (
|
||||
<Modal existingLlmProvider={provider} onOpenChange={setIsOpen} />
|
||||
)}
|
||||
|
||||
{deleteModal.isOpen && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgTrash}
|
||||
@@ -144,9 +202,9 @@ function ExistingProviderCard({
|
||||
onClick={() => setIsOpen(true)}
|
||||
>
|
||||
<CardLayout.Header
|
||||
icon={icon}
|
||||
icon={getProviderIcon(provider.provider)}
|
||||
title={provider.name}
|
||||
description={companyName}
|
||||
description={getProviderDisplayName(provider.provider)}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
tag={isDefault ? { title: "Default", color: "blue" } : undefined}
|
||||
@@ -178,6 +236,8 @@ function ExistingProviderCard({
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
{isOpen &&
|
||||
getModalForExistingProvider(provider, setIsOpen, defaultModelName)}
|
||||
</SelectCard>
|
||||
</Hoverable.Root>
|
||||
</>
|
||||
@@ -191,11 +251,18 @@ function ExistingProviderCard({
|
||||
interface NewProviderCardProps {
|
||||
provider: WellKnownLLMProviderDescriptor;
|
||||
isFirstProvider: boolean;
|
||||
formFn: (
|
||||
shouldMarkAsDefault: boolean,
|
||||
onOpenChange: (open: boolean) => void
|
||||
) => React.ReactNode;
|
||||
}
|
||||
|
||||
function NewProviderCard({ provider, isFirstProvider }: NewProviderCardProps) {
|
||||
function NewProviderCard({
|
||||
provider,
|
||||
isFirstProvider,
|
||||
formFn,
|
||||
}: NewProviderCardProps) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const { icon, productName, companyName, Modal } = getProvider(provider.name);
|
||||
|
||||
return (
|
||||
<SelectCard
|
||||
@@ -205,9 +272,9 @@ function NewProviderCard({ provider, isFirstProvider }: NewProviderCardProps) {
|
||||
onClick={() => setIsOpen(true)}
|
||||
>
|
||||
<CardLayout.Header
|
||||
icon={icon}
|
||||
title={productName}
|
||||
description={companyName}
|
||||
icon={getProviderIcon(provider.name)}
|
||||
title={getProviderProductName(provider.name)}
|
||||
description={getProviderDisplayName(provider.name)}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
rightChildren={
|
||||
@@ -223,9 +290,7 @@ function NewProviderCard({ provider, isFirstProvider }: NewProviderCardProps) {
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
{isOpen && (
|
||||
<Modal shouldMarkAsDefault={isFirstProvider} onOpenChange={setIsOpen} />
|
||||
)}
|
||||
{isOpen && formFn(isFirstProvider, setIsOpen)}
|
||||
</SelectCard>
|
||||
);
|
||||
}
|
||||
@@ -242,7 +307,6 @@ function NewCustomProviderCard({
|
||||
isFirstProvider,
|
||||
}: NewCustomProviderCardProps) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const { icon, productName, companyName, Modal } = getProvider("custom");
|
||||
|
||||
return (
|
||||
<SelectCard
|
||||
@@ -252,9 +316,9 @@ function NewCustomProviderCard({
|
||||
onClick={() => setIsOpen(true)}
|
||||
>
|
||||
<CardLayout.Header
|
||||
icon={icon}
|
||||
title={productName}
|
||||
description={companyName}
|
||||
icon={getProviderIcon("custom")}
|
||||
title={getProviderProductName("custom")}
|
||||
description={getProviderDisplayName("custom")}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
rightChildren={
|
||||
@@ -271,7 +335,10 @@ function NewCustomProviderCard({
|
||||
}
|
||||
/>
|
||||
{isOpen && (
|
||||
<Modal shouldMarkAsDefault={isFirstProvider} onOpenChange={setIsOpen} />
|
||||
<CustomModal
|
||||
shouldMarkAsDefault={isFirstProvider}
|
||||
onOpenChange={setIsOpen}
|
||||
/>
|
||||
)}
|
||||
</SelectCard>
|
||||
);
|
||||
@@ -281,7 +348,7 @@ function NewCustomProviderCard({
|
||||
// LLMConfigurationPage — main page component
|
||||
// ============================================================================
|
||||
|
||||
export default function LLMConfigurationPage() {
|
||||
export default function LLMProviderConfigurationPage() {
|
||||
const { mutate } = useSWRConfig();
|
||||
const { llmProviders: existingLlmProviders, defaultText } =
|
||||
useAdminLLMProviders();
|
||||
@@ -402,6 +469,11 @@ export default function LLMConfigurationPage() {
|
||||
provider={provider}
|
||||
isDefault={defaultText?.provider_id === provider.id}
|
||||
isLastProvider={sortedProviders.length === 1}
|
||||
defaultModelName={
|
||||
defaultText?.provider_id === provider.id
|
||||
? defaultText.model_name
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
@@ -435,13 +507,23 @@ export default function LLMConfigurationPage() {
|
||||
(bIndex === -1 ? Infinity : bIndex)
|
||||
);
|
||||
})
|
||||
.map((provider) => (
|
||||
<NewProviderCard
|
||||
key={provider.name}
|
||||
provider={provider}
|
||||
isFirstProvider={isFirstProvider}
|
||||
/>
|
||||
))}
|
||||
.map((provider) => {
|
||||
const formFn = PROVIDER_MODAL_MAP[provider.name];
|
||||
if (!formFn) {
|
||||
toast.error(
|
||||
`No modal mapping for provider "${provider.name}".`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<NewProviderCard
|
||||
key={provider.name}
|
||||
provider={provider}
|
||||
isFirstProvider={isFirstProvider}
|
||||
formFn={formFn}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
<NewCustomProviderCard isFirstProvider={isFirstProvider} />
|
||||
</div>
|
||||
</GeneralLayouts.Section>
|
||||
@@ -213,12 +213,9 @@ const ChatScrollContainer = React.memo(
|
||||
}
|
||||
}, [updateScrollState, getScrollState]);
|
||||
|
||||
// MutationObserver (structural) + ResizeObserver (height growth).
|
||||
// NOT characterData — typewriter reveals don't change scrollHeight
|
||||
// and firing per-char thrashed auto-scroll.
|
||||
// Watch for content changes (MutationObserver + ResizeObserver)
|
||||
useEffect(() => {
|
||||
const container = scrollContainerRef.current;
|
||||
const contentWrapper = contentWrapperRef.current;
|
||||
if (!container) return;
|
||||
|
||||
let rafId: number | null = null;
|
||||
@@ -247,17 +244,17 @@ const ChatScrollContainer = React.memo(
|
||||
});
|
||||
};
|
||||
|
||||
// MutationObserver for content changes
|
||||
const mutationObserver = new MutationObserver(onContentChange);
|
||||
mutationObserver.observe(container, {
|
||||
childList: true,
|
||||
subtree: true,
|
||||
characterData: true,
|
||||
});
|
||||
|
||||
// ResizeObserver for container size changes
|
||||
const resizeObserver = new ResizeObserver(onContentChange);
|
||||
resizeObserver.observe(container);
|
||||
if (contentWrapper) {
|
||||
resizeObserver.observe(contentWrapper);
|
||||
}
|
||||
|
||||
return () => {
|
||||
mutationObserver.disconnect();
|
||||
@@ -355,7 +352,6 @@ const ChatScrollContainer = React.memo(
|
||||
key={sessionId}
|
||||
ref={scrollContainerRef}
|
||||
data-testid="chat-scroll-container"
|
||||
data-chat-scroll
|
||||
className={cn(
|
||||
"flex flex-col flex-1 min-h-0 overflow-y-auto overflow-x-hidden",
|
||||
hideScrollbar ? "no-scrollbar" : "default-scrollbar"
|
||||
|
||||
@@ -331,13 +331,10 @@ const ChatUI = React.memo(
|
||||
return null;
|
||||
})}
|
||||
|
||||
{/* Error banner when last message is user message or error type.
|
||||
Skip for multi-model per-panel errors — those are shown in
|
||||
their own panel, not as a global banner. */}
|
||||
{/* Error banner when last message is user message or error type */}
|
||||
{(((error !== null || loadError !== null) &&
|
||||
messages[messages.length - 1]?.type === "user") ||
|
||||
(messages[messages.length - 1]?.type === "error" &&
|
||||
!messages[messages.length - 1]?.modelDisplayName)) && (
|
||||
messages[messages.length - 1]?.type === "error") && (
|
||||
<div className={`p-4 w-full ${MSG_MAX_W} self-center`}>
|
||||
<ErrorBanner
|
||||
resubmit={onResubmit}
|
||||
|
||||
@@ -86,7 +86,6 @@ export interface AppInputBarProps {
|
||||
deepResearchEnabled: boolean;
|
||||
setPresentingDocument?: (document: MinimalOnyxDocument) => void;
|
||||
toggleDeepResearch: () => void;
|
||||
isMultiModelActive?: boolean;
|
||||
disabled: boolean;
|
||||
ref?: React.Ref<AppInputBarHandle>;
|
||||
// Side panel tab reading
|
||||
@@ -110,7 +109,6 @@ const AppInputBar = React.memo(
|
||||
llmManager,
|
||||
deepResearchEnabled,
|
||||
toggleDeepResearch,
|
||||
isMultiModelActive,
|
||||
setPresentingDocument,
|
||||
disabled,
|
||||
ref,
|
||||
@@ -556,17 +554,12 @@ const AppInputBar = React.memo(
|
||||
) : (
|
||||
showDeepResearch && (
|
||||
<SelectButton
|
||||
disabled={disabled || isMultiModelActive}
|
||||
disabled={disabled}
|
||||
variant="select-light"
|
||||
icon={SvgHourglass}
|
||||
onClick={toggleDeepResearch}
|
||||
state={deepResearchEnabled ? "selected" : "empty"}
|
||||
foldable={!deepResearchEnabled}
|
||||
tooltip={
|
||||
isMultiModelActive
|
||||
? "Deep Research is disabled in multi-model mode"
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
Deep Research
|
||||
</SelectButton>
|
||||
|
||||
@@ -50,7 +50,7 @@ function BifrostModalInternals({
|
||||
const { models, error } = await fetchBifrostModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.BIFROST,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import {
|
||||
@@ -29,8 +29,9 @@ import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { Button, Card, EmptyMessageCard } from "@opal/components";
|
||||
import { SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
|
||||
import { SvgMinusCircle, SvgPlusCircle, SvgRefreshCw } from "@opal/icons";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
@@ -110,6 +111,95 @@ function ModelConfigurationItem({
|
||||
);
|
||||
}
|
||||
|
||||
interface FetchedModel {
|
||||
name: string;
|
||||
display_name: string;
|
||||
max_input_tokens: number | null;
|
||||
supports_image_input: boolean;
|
||||
}
|
||||
|
||||
function FetchModelsButton({ provider }: { provider: string }) {
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
const [isFetching, setIsFetching] = useState(false);
|
||||
const formikProps = useFormikContext<{
|
||||
api_base?: string;
|
||||
api_key?: string;
|
||||
api_version?: string;
|
||||
model_configurations: CustomModelConfiguration[];
|
||||
}>();
|
||||
|
||||
useEffect(() => {
|
||||
return () => abortRef.current?.abort();
|
||||
}, []);
|
||||
|
||||
async function handleFetch() {
|
||||
abortRef.current?.abort();
|
||||
const controller = new AbortController();
|
||||
abortRef.current = controller;
|
||||
setIsFetching(true);
|
||||
try {
|
||||
const response = await fetch("/api/admin/llm/custom/available-models", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider,
|
||||
api_base: formikProps.values.api_base || undefined,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
api_version: formikProps.values.api_version || undefined,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
});
|
||||
if (!response.ok) {
|
||||
let errorMessage = "Failed to fetch models";
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorMessage;
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
}
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
const fetched: FetchedModel[] = await response.json();
|
||||
const existing = formikProps.values.model_configurations;
|
||||
const existingNames = new Set(existing.map((m) => m.name));
|
||||
const newModels: CustomModelConfiguration[] = fetched
|
||||
.filter((m) => !existingNames.has(m.name))
|
||||
.map((m) => ({
|
||||
name: m.name,
|
||||
display_name: m.display_name !== m.name ? m.display_name : "",
|
||||
max_input_tokens: m.max_input_tokens,
|
||||
supports_image_input: m.supports_image_input,
|
||||
}));
|
||||
// Replace empty placeholder rows, then merge
|
||||
const nonEmpty = existing.filter((m) => m.name.trim() !== "");
|
||||
formikProps.setFieldValue("model_configurations", [
|
||||
...nonEmpty,
|
||||
...newModels,
|
||||
]);
|
||||
toast.success(`Fetched ${fetched.length} models`);
|
||||
} catch (err) {
|
||||
if (err instanceof DOMException && err.name === "AbortError") return;
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to fetch models"
|
||||
);
|
||||
} finally {
|
||||
if (!controller.signal.aborted) {
|
||||
setIsFetching(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={isFetching ? SimpleLoader : SvgRefreshCw}
|
||||
onClick={handleFetch}
|
||||
disabled={isFetching || !provider}
|
||||
type="button"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function ModelConfigurationList() {
|
||||
const formikProps = useFormikContext<{
|
||||
model_configurations: CustomModelConfiguration[];
|
||||
@@ -222,6 +312,24 @@ function ProviderNameSelect({ disabled }: { disabled?: boolean }) {
|
||||
);
|
||||
}
|
||||
|
||||
function ModelsHeader() {
|
||||
const { values } = useFormikContext<{ provider: string }>();
|
||||
return (
|
||||
<InputLayouts.Horizontal
|
||||
title="Models"
|
||||
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
|
||||
nonInteractive
|
||||
center
|
||||
>
|
||||
{values.provider ? (
|
||||
<FetchModelsButton provider={values.provider} />
|
||||
) : (
|
||||
<div />
|
||||
)}
|
||||
</InputLayouts.Horizontal>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Custom Config Processing ─────────────────────────────────────────────────
|
||||
|
||||
function keyValueListToDict(items: KeyValue[]): Record<string, string> {
|
||||
@@ -424,13 +532,7 @@ export default function CustomModal({
|
||||
<InputLayouts.FieldSeparator />
|
||||
<Section gap={0.5}>
|
||||
<InputLayouts.FieldPadder>
|
||||
<Content
|
||||
title="Models"
|
||||
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
widthVariant="full"
|
||||
/>
|
||||
<ModelsHeader />
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<Card padding="sm">
|
||||
|
||||
@@ -52,7 +52,7 @@ function LiteLLMProxyModalInternals({
|
||||
const { models, error } = await fetchLiteLLMProxyModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.LITELLM_PROXY,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
|
||||
@@ -52,7 +52,7 @@ function OpenRouterModalInternals({
|
||||
const { models, error } = await fetchOpenRouterModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.OPENROUTER,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
|
||||
75
web/src/sections/modals/llmConfig/getModal.tsx
Normal file
75
web/src/sections/modals/llmConfig/getModal.tsx
Normal file
@@ -0,0 +1,75 @@
|
||||
import { LLMProviderName, LLMProviderView } from "@/interfaces/llm";
|
||||
import AnthropicModal from "@/sections/modals/llmConfig/AnthropicModal";
|
||||
import OpenAIModal from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
import OllamaModal from "@/sections/modals/llmConfig/OllamaModal";
|
||||
import AzureModal from "@/sections/modals/llmConfig/AzureModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
|
||||
export function getModalForExistingProvider(
|
||||
provider: LLMProviderView,
|
||||
onOpenChange?: (open: boolean) => void,
|
||||
defaultModelName?: string
|
||||
) {
|
||||
const props = {
|
||||
existingLlmProvider: provider,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
};
|
||||
|
||||
const hasCustomConfig = provider.custom_config != null;
|
||||
|
||||
switch (provider.provider) {
|
||||
// These providers don't use custom_config themselves, so a non-null
|
||||
// custom_config means the provider was created via CustomModal.
|
||||
case LLMProviderName.OPENAI:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<OpenAIModal {...props} />
|
||||
);
|
||||
case LLMProviderName.ANTHROPIC:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<AnthropicModal {...props} />
|
||||
);
|
||||
case LLMProviderName.AZURE:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<AzureModal {...props} />
|
||||
);
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<OpenRouterModal {...props} />
|
||||
);
|
||||
|
||||
// These providers legitimately store settings in custom_config,
|
||||
// so always use their dedicated modals.
|
||||
case LLMProviderName.OLLAMA_CHAT:
|
||||
return <OllamaModal {...props} />;
|
||||
case LLMProviderName.VERTEX_AI:
|
||||
return <VertexAIModal {...props} />;
|
||||
case LLMProviderName.BEDROCK:
|
||||
return <BedrockModal {...props} />;
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
return <LMStudioModal {...props} />;
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
return <LiteLLMProxyModal {...props} />;
|
||||
case LLMProviderName.BIFROST:
|
||||
return <BifrostModal {...props} />;
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return <OpenAICompatibleModal {...props} />;
|
||||
default:
|
||||
return <CustomModal {...props} />;
|
||||
}
|
||||
}
|
||||
@@ -44,7 +44,11 @@ import useUsers from "@/hooks/useUsers";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { UserRole } from "@/lib/types";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { getProvider } from "@/lib/llmConfig";
|
||||
import {
|
||||
getProviderIcon,
|
||||
getProviderDisplayName,
|
||||
getProviderProductName,
|
||||
} from "@/lib/llmConfig/providers";
|
||||
|
||||
// ─── DisplayNameField ────────────────────────────────────────────────────────
|
||||
|
||||
@@ -713,11 +717,9 @@ function ModalWrapperInner({
|
||||
? "No changes to save."
|
||||
: undefined;
|
||||
|
||||
const {
|
||||
icon: providerIcon,
|
||||
companyName: providerDisplayName,
|
||||
productName: providerProductName,
|
||||
} = getProvider(providerName);
|
||||
const providerIcon = getProviderIcon(providerName);
|
||||
const providerDisplayName = getProviderDisplayName(providerName);
|
||||
const providerProductName = getProviderProductName(providerName);
|
||||
|
||||
const title = llmProvider
|
||||
? `Configure "${llmProvider.name}"`
|
||||
|
||||
145
web/src/sections/onboarding/forms/getOnboardingForm.tsx
Normal file
145
web/src/sections/onboarding/forms/getOnboardingForm.tsx
Normal file
@@ -0,0 +1,145 @@
|
||||
import React from "react";
|
||||
import {
|
||||
WellKnownLLMProviderDescriptor,
|
||||
LLMProviderName,
|
||||
LLMProviderFormProps,
|
||||
} from "@/interfaces/llm";
|
||||
import { OnboardingActions, OnboardingState } from "@/interfaces/onboarding";
|
||||
import OpenAIModal from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
import AnthropicModal from "@/sections/modals/llmConfig/AnthropicModal";
|
||||
import OllamaModal from "@/sections/modals/llmConfig/OllamaModal";
|
||||
import AzureModal from "@/sections/modals/llmConfig/AzureModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
|
||||
// Display info for LLM provider cards - title is the product name, displayName is the company/platform
|
||||
const PROVIDER_DISPLAY_INFO: Record<
|
||||
string,
|
||||
{ title: string; displayName: string }
|
||||
> = {
|
||||
[LLMProviderName.OPENAI]: { title: "GPT", displayName: "OpenAI" },
|
||||
[LLMProviderName.ANTHROPIC]: { title: "Claude", displayName: "Anthropic" },
|
||||
[LLMProviderName.OLLAMA_CHAT]: { title: "Ollama", displayName: "Ollama" },
|
||||
[LLMProviderName.AZURE]: {
|
||||
title: "Azure OpenAI",
|
||||
displayName: "Microsoft Azure Cloud",
|
||||
},
|
||||
[LLMProviderName.BEDROCK]: {
|
||||
title: "Amazon Bedrock",
|
||||
displayName: "AWS",
|
||||
},
|
||||
[LLMProviderName.VERTEX_AI]: {
|
||||
title: "Gemini",
|
||||
displayName: "Google Cloud Vertex AI",
|
||||
},
|
||||
[LLMProviderName.OPENROUTER]: {
|
||||
title: "OpenRouter",
|
||||
displayName: "OpenRouter",
|
||||
},
|
||||
[LLMProviderName.LM_STUDIO]: {
|
||||
title: "LM Studio",
|
||||
displayName: "LM Studio",
|
||||
},
|
||||
[LLMProviderName.LITELLM_PROXY]: {
|
||||
title: "LiteLLM Proxy",
|
||||
displayName: "LiteLLM Proxy",
|
||||
},
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: {
|
||||
title: "OpenAI-Compatible",
|
||||
displayName: "OpenAI-Compatible",
|
||||
},
|
||||
};
|
||||
|
||||
export function getProviderDisplayInfo(providerName: string): {
|
||||
title: string;
|
||||
displayName: string;
|
||||
} {
|
||||
return (
|
||||
PROVIDER_DISPLAY_INFO[providerName] ?? {
|
||||
title: providerName,
|
||||
displayName: providerName,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
export interface OnboardingFormProps {
|
||||
llmDescriptor?: WellKnownLLMProviderDescriptor;
|
||||
isCustomProvider?: boolean;
|
||||
onboardingState: OnboardingState;
|
||||
onboardingActions: OnboardingActions;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
}
|
||||
|
||||
export function getOnboardingForm({
|
||||
llmDescriptor,
|
||||
isCustomProvider,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
onOpenChange,
|
||||
}: OnboardingFormProps): React.ReactNode {
|
||||
const providerName = isCustomProvider
|
||||
? "custom"
|
||||
: llmDescriptor?.name ?? "custom";
|
||||
|
||||
const sharedProps: LLMProviderFormProps = {
|
||||
variant: "onboarding" as const,
|
||||
shouldMarkAsDefault:
|
||||
(onboardingState?.data.llmProviders ?? []).length === 0,
|
||||
onboardingActions,
|
||||
onOpenChange,
|
||||
onSuccess: () => {
|
||||
onboardingActions.updateData({
|
||||
llmProviders: [
|
||||
...(onboardingState?.data.llmProviders ?? []),
|
||||
providerName,
|
||||
],
|
||||
});
|
||||
onboardingActions.setButtonActive(true);
|
||||
},
|
||||
};
|
||||
|
||||
// Handle custom provider
|
||||
if (isCustomProvider || !llmDescriptor) {
|
||||
return <CustomModal {...sharedProps} />;
|
||||
}
|
||||
|
||||
switch (llmDescriptor.name) {
|
||||
case LLMProviderName.OPENAI:
|
||||
return <OpenAIModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.ANTHROPIC:
|
||||
return <AnthropicModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OLLAMA_CHAT:
|
||||
return <OllamaModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.AZURE:
|
||||
return <AzureModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.BEDROCK:
|
||||
return <BedrockModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.VERTEX_AI:
|
||||
return <VertexAIModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return <OpenRouterModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
return <LMStudioModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
return <LiteLLMProxyModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return <OpenAICompatibleModal {...sharedProps} />;
|
||||
|
||||
default:
|
||||
return <CustomModal {...sharedProps} />;
|
||||
}
|
||||
}
|
||||
@@ -4,29 +4,35 @@ import { memo, useState, useCallback } from "react";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button } from "@opal/components";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import LLMProviderCard from "@/sections/onboarding/components/LLMProviderCard";
|
||||
import LLMProviderCard from "../components/LLMProviderCard";
|
||||
import {
|
||||
OnboardingActions,
|
||||
OnboardingState,
|
||||
OnboardingStep,
|
||||
} from "@/interfaces/onboarding";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import { getProvider } from "@/lib/llmConfig";
|
||||
getOnboardingForm,
|
||||
getProviderDisplayInfo,
|
||||
} from "../forms/getOnboardingForm";
|
||||
import { Disabled } from "@opal/core";
|
||||
import ModelIcon from "@/app/admin/configuration/llm/ModelIcon";
|
||||
import { SvgCheckCircle, SvgCpu, SvgExternalLink } from "@opal/icons";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { useLLMProviderOptions } from "@/lib/hooks/useLLMProviderOptions";
|
||||
|
||||
type LLMStepProps = {
|
||||
state: OnboardingState;
|
||||
actions: OnboardingActions;
|
||||
disabled?: boolean;
|
||||
};
|
||||
|
||||
interface SelectedProvider {
|
||||
llmDescriptor?: WellKnownLLMProviderDescriptor;
|
||||
isCustomProvider: boolean;
|
||||
}
|
||||
|
||||
function LLMProviderSkeleton() {
|
||||
const LLMProviderSkeleton = () => {
|
||||
return (
|
||||
<div className="flex justify-between h-full w-full p-1 rounded-12 border border-border-01 bg-background-neutral-01 animate-pulse">
|
||||
<div className="flex gap-1 p-1 flex-1 min-w-0">
|
||||
@@ -41,11 +47,12 @@ function LLMProviderSkeleton() {
|
||||
<div className="h-6 w-16 bg-neutral-200 rounded" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
interface StackedProviderIconsProps {
|
||||
type StackedProviderIconsProps = {
|
||||
providers: string[];
|
||||
}
|
||||
};
|
||||
|
||||
const StackedProviderIcons = ({ providers }: StackedProviderIconsProps) => {
|
||||
if (!providers || providers.length === 0) {
|
||||
return null;
|
||||
@@ -82,157 +89,133 @@ const StackedProviderIcons = ({ providers }: StackedProviderIconsProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
interface LLMStepProps {
|
||||
state: OnboardingState;
|
||||
actions: OnboardingActions;
|
||||
disabled?: boolean;
|
||||
}
|
||||
const LLMStep = memo(
|
||||
({
|
||||
state: onboardingState,
|
||||
actions: onboardingActions,
|
||||
disabled,
|
||||
}: LLMStepProps) => {
|
||||
const { llmProviderOptions, isLoading } = useLLMProviderOptions();
|
||||
const llmDescriptors = llmProviderOptions ?? [];
|
||||
const LLMStepInner = ({
|
||||
state: onboardingState,
|
||||
actions: onboardingActions,
|
||||
disabled,
|
||||
}: LLMStepProps) => {
|
||||
const { llmProviderOptions, isLoading } = useLLMProviderOptions();
|
||||
const llmDescriptors = llmProviderOptions ?? [];
|
||||
|
||||
const [selectedProvider, setSelectedProvider] =
|
||||
useState<SelectedProvider | null>(null);
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
const [selectedProvider, setSelectedProvider] =
|
||||
useState<SelectedProvider | null>(null);
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
|
||||
const handleProviderClick = useCallback(
|
||||
(
|
||||
llmDescriptor?: WellKnownLLMProviderDescriptor,
|
||||
isCustomProvider: boolean = false
|
||||
) => {
|
||||
setSelectedProvider({ llmDescriptor, isCustomProvider });
|
||||
setIsModalOpen(true);
|
||||
},
|
||||
[]
|
||||
);
|
||||
const handleProviderClick = useCallback(
|
||||
(
|
||||
llmDescriptor?: WellKnownLLMProviderDescriptor,
|
||||
isCustomProvider: boolean = false
|
||||
) => {
|
||||
setSelectedProvider({ llmDescriptor, isCustomProvider });
|
||||
setIsModalOpen(true);
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const handleModalClose = useCallback((open: boolean) => {
|
||||
setIsModalOpen(open);
|
||||
if (!open) {
|
||||
setSelectedProvider(null);
|
||||
}
|
||||
}, []);
|
||||
const handleModalClose = useCallback((open: boolean) => {
|
||||
setIsModalOpen(open);
|
||||
if (!open) {
|
||||
setSelectedProvider(null);
|
||||
}
|
||||
}, []);
|
||||
|
||||
if (
|
||||
onboardingState.currentStep === OnboardingStep.LlmSetup ||
|
||||
onboardingState.currentStep === OnboardingStep.Name
|
||||
) {
|
||||
const providerName = selectedProvider?.isCustomProvider
|
||||
? "custom"
|
||||
: selectedProvider?.llmDescriptor?.name ?? "custom";
|
||||
|
||||
const { Modal: ModalComponent } = getProvider(providerName);
|
||||
|
||||
const modalProps: LLMProviderFormProps = {
|
||||
variant: "onboarding" as const,
|
||||
shouldMarkAsDefault:
|
||||
(onboardingState?.data.llmProviders ?? []).length === 0,
|
||||
onboardingActions,
|
||||
onOpenChange: handleModalClose,
|
||||
onSuccess: () => {
|
||||
onboardingActions.updateData({
|
||||
llmProviders: [
|
||||
...(onboardingState?.data.llmProviders ?? []),
|
||||
providerName,
|
||||
],
|
||||
});
|
||||
onboardingActions.setButtonActive(true);
|
||||
},
|
||||
};
|
||||
|
||||
return (
|
||||
<Disabled disabled={disabled} allowClick>
|
||||
<div
|
||||
className="flex flex-col items-center justify-between w-full p-1 rounded-16 border border-border-01 bg-background-tint-00"
|
||||
aria-label="onboarding-llm-step"
|
||||
>
|
||||
<ContentAction
|
||||
icon={SvgCpu}
|
||||
title="Connect your LLM models"
|
||||
description="Onyx supports both self-hosted models and popular providers."
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="lg"
|
||||
rightChildren={
|
||||
<Button
|
||||
disabled={disabled}
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgExternalLink}
|
||||
href="/admin/configuration/llm"
|
||||
if (
|
||||
onboardingState.currentStep === OnboardingStep.LlmSetup ||
|
||||
onboardingState.currentStep === OnboardingStep.Name
|
||||
) {
|
||||
return (
|
||||
<Disabled disabled={disabled} allowClick>
|
||||
<div
|
||||
className="flex flex-col items-center justify-between w-full p-1 rounded-16 border border-border-01 bg-background-tint-00"
|
||||
aria-label="onboarding-llm-step"
|
||||
>
|
||||
<ContentAction
|
||||
icon={SvgCpu}
|
||||
title="Connect your LLM models"
|
||||
description="Onyx supports both self-hosted models and popular providers."
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="lg"
|
||||
rightChildren={
|
||||
<Button
|
||||
disabled={disabled}
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgExternalLink}
|
||||
href="/admin/configuration/llm"
|
||||
>
|
||||
View in Admin Panel
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
<Separator />
|
||||
<div className="flex flex-wrap gap-1 [&>*:last-child:nth-child(odd)]:basis-full">
|
||||
{isLoading ? (
|
||||
Array.from({ length: 8 }).map((_, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className="basis-[calc(50%-theme(spacing.1)/2)] grow"
|
||||
>
|
||||
View in Admin Panel
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
<Separator />
|
||||
<div className="flex flex-wrap gap-1 [&>*:last-child:nth-child(odd)]:basis-full">
|
||||
{isLoading ? (
|
||||
Array.from({ length: 8 }).map((_, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className="basis-[calc(50%-theme(spacing.1)/2)] grow"
|
||||
>
|
||||
<LLMProviderSkeleton />
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<>
|
||||
{/* Render the selected provider form */}
|
||||
{selectedProvider && isModalOpen && (
|
||||
<ModalComponent {...modalProps} />
|
||||
)}
|
||||
|
||||
{/* Render provider cards */}
|
||||
{llmDescriptors.map((llmDescriptor) => {
|
||||
const { productName, companyName } = getProvider(
|
||||
llmDescriptor.name
|
||||
);
|
||||
return (
|
||||
<div
|
||||
key={llmDescriptor.name}
|
||||
className="basis-[calc(50%-theme(spacing.1)/2)] grow"
|
||||
>
|
||||
<LLMProviderCard
|
||||
title={productName}
|
||||
subtitle={companyName}
|
||||
providerName={llmDescriptor.name}
|
||||
disabled={disabled}
|
||||
isConnected={onboardingState.data.llmProviders?.some(
|
||||
(provider) => provider === llmDescriptor.name
|
||||
)}
|
||||
onClick={() =>
|
||||
handleProviderClick(llmDescriptor, false)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
<LLMProviderSkeleton />
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<>
|
||||
{/* Render the selected provider form */}
|
||||
{selectedProvider &&
|
||||
isModalOpen &&
|
||||
getOnboardingForm({
|
||||
llmDescriptor: selectedProvider.llmDescriptor,
|
||||
isCustomProvider: selectedProvider.isCustomProvider,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
onOpenChange: handleModalClose,
|
||||
})}
|
||||
|
||||
{/* Custom provider card */}
|
||||
<div className="basis-[calc(50%-theme(spacing.1)/2)] grow">
|
||||
<LLMProviderCard
|
||||
title="Custom LLM Provider"
|
||||
subtitle="LiteLLM Compatible APIs"
|
||||
disabled={disabled}
|
||||
isConnected={onboardingState.data.llmProviders?.some(
|
||||
(provider) => provider === "custom"
|
||||
)}
|
||||
onClick={() => handleProviderClick(undefined, true)}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Disabled>
|
||||
);
|
||||
}
|
||||
{/* Render provider cards */}
|
||||
{llmDescriptors.map((llmDescriptor) => {
|
||||
const displayInfo = getProviderDisplayInfo(
|
||||
llmDescriptor.name
|
||||
);
|
||||
return (
|
||||
<div
|
||||
key={llmDescriptor.name}
|
||||
className="basis-[calc(50%-theme(spacing.1)/2)] grow"
|
||||
>
|
||||
<LLMProviderCard
|
||||
title={displayInfo.title}
|
||||
subtitle={displayInfo.displayName}
|
||||
providerName={llmDescriptor.name}
|
||||
disabled={disabled}
|
||||
isConnected={onboardingState.data.llmProviders?.some(
|
||||
(provider) => provider === llmDescriptor.name
|
||||
)}
|
||||
onClick={() =>
|
||||
handleProviderClick(llmDescriptor, false)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Custom provider card */}
|
||||
<div className="basis-[calc(50%-theme(spacing.1)/2)] grow">
|
||||
<LLMProviderCard
|
||||
title="Custom LLM Provider"
|
||||
subtitle="LiteLLM Compatible APIs"
|
||||
disabled={disabled}
|
||||
isConnected={onboardingState.data.llmProviders?.some(
|
||||
(provider) => provider === "custom"
|
||||
)}
|
||||
onClick={() => handleProviderClick(undefined, true)}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Disabled>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
@@ -261,7 +244,7 @@ const LLMStep = memo(
|
||||
</button>
|
||||
);
|
||||
}
|
||||
);
|
||||
LLMStep.displayName = "LLMStep";
|
||||
};
|
||||
|
||||
const LLMStep = memo(LLMStepInner);
|
||||
export default LLMStep;
|
||||
|
||||
Reference in New Issue
Block a user