mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-12 10:22:42 +00:00
Compare commits
1 Commits
v3.2.2
...
jamison/fe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c709a4fb6 |
@@ -9,6 +9,7 @@ repos:
|
||||
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
hooks:
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
@@ -17,7 +18,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"backend",
|
||||
"-o",
|
||||
"backend/requirements/default.txt",
|
||||
@@ -30,7 +31,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"dev",
|
||||
"-o",
|
||||
"backend/requirements/dev.txt",
|
||||
@@ -43,7 +44,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"ee",
|
||||
"-o",
|
||||
"backend/requirements/ee.txt",
|
||||
@@ -56,7 +57,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--group",
|
||||
"--extra",
|
||||
"model_server",
|
||||
"-o",
|
||||
"backend/requirements/model_server.txt",
|
||||
|
||||
3
.vscode/launch.json
vendored
3
.vscode/launch.json
vendored
@@ -531,7 +531,8 @@
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"sync"
|
||||
"sync",
|
||||
"--all-extras"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
|
||||
@@ -117,7 +117,7 @@ If using PowerShell, the command slightly differs:
|
||||
Install the required Python dependencies:
|
||||
|
||||
```bash
|
||||
uv sync
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector):
|
||||
|
||||
@@ -13,7 +13,6 @@ from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -108,13 +107,12 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
Get current seat usage directly from database.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users.
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role
|
||||
and the anonymous system user).
|
||||
|
||||
Only human accounts count toward seat limits.
|
||||
SERVICE_ACCOUNT (API key dummy users), EXT_PERM_USER, and the
|
||||
anonymous system user are excluded. BOT (Slack users) ARE counted
|
||||
because they represent real humans and get upgraded to STANDARD
|
||||
when they log in via web.
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
@@ -131,7 +129,6 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
|
||||
User.account_type != AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -11,8 +11,6 @@ require a valid SCIM bearer token.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -24,7 +22,6 @@ from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -68,25 +65,12 @@ from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
# Namespace prefix for the seat-allocation advisory lock. Hashed together
|
||||
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
|
||||
# never block each other) and cannot collide with unrelated advisory locks.
|
||||
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
|
||||
|
||||
|
||||
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
|
||||
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
|
||||
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
|
||||
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
|
||||
return struct.unpack("q", digest[:8])[0]
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -225,37 +209,12 @@ def _apply_exclusions(
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None.
|
||||
|
||||
Acquires a transaction-scoped advisory lock so that concurrent
|
||||
SCIM requests are serialized. IdPs like Okta send provisioning
|
||||
requests in parallel batches — without serialization the check is
|
||||
vulnerable to a TOCTOU race where N concurrent requests each see
|
||||
"seats available", all insert, and the tenant ends up over its
|
||||
seat limit.
|
||||
|
||||
The lock is held until the caller's next COMMIT or ROLLBACK, which
|
||||
means the seat count cannot change between the check here and the
|
||||
subsequent INSERT/UPDATE. Each call site in this module follows
|
||||
the pattern: _check_seat_availability → write → dal.commit()
|
||||
(which releases the lock for the next waiting request).
|
||||
"""
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
|
||||
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
|
||||
# The lock id is derived from the tenant so unrelated tenants never block
|
||||
# each other, and from a namespace string so it cannot collide with
|
||||
# unrelated advisory locks elsewhere in the codebase.
|
||||
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
|
||||
dal.session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(:lock_id)"),
|
||||
{"lock_id": lock_id},
|
||||
)
|
||||
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
|
||||
@@ -60,10 +60,8 @@ logger = setup_logger()
|
||||
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
_JIRA_BULK_FETCH_LIMIT = 100
|
||||
|
||||
# Constants for Jira field names
|
||||
_FIELD_REPORTER = "reporter"
|
||||
@@ -257,13 +255,15 @@ def _bulk_fetch_request(
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def _bulk_fetch_batch(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch a single batch (must be <= _JIRA_BULK_FETCH_LIMIT).
|
||||
On JSONDecodeError, recursively bisects until it succeeds or reaches size 1."""
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
return _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
@@ -277,25 +277,12 @@ def _bulk_fetch_batch(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = _bulk_fetch_batch(jira_client, issue_ids[:mid], fields)
|
||||
right = _bulk_fetch_batch(jira_client, issue_ids[mid:], fields)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
raw_issues: list[dict[str, Any]] = []
|
||||
for batch in chunked(issue_ids, _JIRA_BULK_FETCH_LIMIT):
|
||||
try:
|
||||
raw_issues.extend(_bulk_fetch_batch(jira_client, list(batch), fields))
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -7,14 +6,6 @@ from pydantic import BaseModel
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DirectThreadFetch:
|
||||
"""Request to fetch a Slack thread directly by channel and timestamp."""
|
||||
|
||||
channel_id: str
|
||||
thread_ts: str
|
||||
|
||||
|
||||
class ChannelMetadata(TypedDict):
|
||||
"""Type definition for cached channel metadata."""
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.models import SlackMessage
|
||||
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
|
||||
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
|
||||
@@ -50,6 +49,7 @@ from onyx.server.federated.models import FederatedConnectorDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -58,6 +58,7 @@ HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
|
||||
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
|
||||
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
|
||||
|
||||
@@ -420,94 +421,6 @@ class SlackQueryResult(BaseModel):
|
||||
filtered_channels: list[str] # Channels filtered out during this query
|
||||
|
||||
|
||||
def _fetch_thread_from_url(
|
||||
thread_fetch: DirectThreadFetch,
|
||||
access_token: str,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
"""Fetch a thread directly from a Slack URL via conversations.replies."""
|
||||
channel_id = thread_fetch.channel_id
|
||||
thread_ts = thread_fetch.thread_ts
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
response = slack_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
)
|
||||
response.validate()
|
||||
messages: list[dict[str, Any]] = response.get("messages", [])
|
||||
except SlackApiError as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
if not messages:
|
||||
logger.warning(
|
||||
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
# Build thread text from all messages
|
||||
thread_text = _build_thread_text(messages, access_token, None, slack_client)
|
||||
|
||||
# Get channel name from metadata cache or API
|
||||
channel_name = "unknown"
|
||||
if channel_metadata_dict and channel_id in channel_metadata_dict:
|
||||
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
|
||||
else:
|
||||
try:
|
||||
ch_response = slack_client.conversations_info(channel=channel_id)
|
||||
ch_response.validate()
|
||||
channel_info: dict[str, Any] = ch_response.get("channel", {})
|
||||
channel_name = channel_info.get("name", "unknown")
|
||||
except SlackApiError:
|
||||
pass
|
||||
|
||||
# Build the SlackMessage
|
||||
parent_msg = messages[0]
|
||||
message_ts = parent_msg.get("ts", thread_ts)
|
||||
username = parent_msg.get("user", "unknown_user")
|
||||
parent_text = parent_msg.get("text", "")
|
||||
snippet = (
|
||||
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
|
||||
).replace("\n", " ")
|
||||
|
||||
doc_time = datetime.fromtimestamp(float(message_ts))
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
|
||||
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
|
||||
|
||||
permalink = (
|
||||
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
|
||||
)
|
||||
|
||||
slack_message = SlackMessage(
|
||||
document_id=f"{channel_id}_{message_ts}",
|
||||
channel_id=channel_id,
|
||||
message_id=message_ts,
|
||||
thread_id=None, # Prevent double-enrichment in thread context fetch
|
||||
link=permalink,
|
||||
metadata={
|
||||
"channel": channel_name,
|
||||
"time": doc_time.isoformat(),
|
||||
},
|
||||
timestamp=doc_time,
|
||||
recency_bias=recency_bias,
|
||||
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
|
||||
text=thread_text,
|
||||
highlighted_texts=set(),
|
||||
slack_score=100000.0, # High priority — user explicitly asked for this thread
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
|
||||
)
|
||||
|
||||
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
access_token: str,
|
||||
@@ -519,6 +432,7 @@ def query_slack(
|
||||
available_channels: list[str] | None = None,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
|
||||
# Check if query has channel override (user specified channels in query)
|
||||
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
|
||||
|
||||
@@ -748,6 +662,7 @@ def _fetch_thread_context(
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# If not a thread, return original text as success
|
||||
if thread_id is None:
|
||||
@@ -780,37 +695,62 @@ def _fetch_thread_context(
|
||||
if len(messages) <= 1:
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# Build thread text from thread starter + all replies
|
||||
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build thread text including all replies.
|
||||
|
||||
Includes the thread parent message followed by all replies in order.
|
||||
"""
|
||||
"""Build the thread text from messages."""
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# All messages after index 0 are replies
|
||||
replies = messages[1:]
|
||||
if not replies:
|
||||
return thread_text
|
||||
|
||||
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
else:
|
||||
message_id_idx = next(
|
||||
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
|
||||
)
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
for msg in replies:
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add following replies
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
msg_text = msg.get("text", "")
|
||||
msg_sender = msg.get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
|
||||
# Replace user IDs with names using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
@@ -1036,16 +976,7 @@ def slack_retrieval(
|
||||
|
||||
# Query slack with entity filtering
|
||||
llm = get_default_llm()
|
||||
query_items = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Partition into direct thread fetches and search query strings
|
||||
direct_fetches: list[DirectThreadFetch] = []
|
||||
query_strings: list[str] = []
|
||||
for item in query_items:
|
||||
if isinstance(item, DirectThreadFetch):
|
||||
direct_fetches.append(item)
|
||||
else:
|
||||
query_strings.append(item)
|
||||
query_strings = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Determine filtering based on entities OR context (bot)
|
||||
include_dm = False
|
||||
@@ -1062,16 +993,8 @@ def slack_retrieval(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
# Build search tasks — direct thread fetches + keyword searches
|
||||
search_tasks: list[tuple] = [
|
||||
(
|
||||
_fetch_thread_from_url,
|
||||
(fetch, access_token, channel_metadata_dict),
|
||||
)
|
||||
for fetch in direct_fetches
|
||||
]
|
||||
|
||||
search_tasks.extend(
|
||||
# Build search tasks
|
||||
search_tasks = [
|
||||
(
|
||||
query_slack,
|
||||
(
|
||||
@@ -1087,7 +1010,7 @@ def slack_retrieval(
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
)
|
||||
]
|
||||
|
||||
# If include_dm is True AND we're not already searching all channels,
|
||||
# add additional searches without channel filters.
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import ValidationError
|
||||
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -639,38 +638,12 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
return [query_text]
|
||||
|
||||
|
||||
SLACK_URL_PATTERN = re.compile(
|
||||
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
|
||||
)
|
||||
|
||||
|
||||
def extract_slack_message_urls(
|
||||
query_text: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Extract Slack message URLs from query text.
|
||||
|
||||
Parses URLs like:
|
||||
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
|
||||
|
||||
Returns list of (channel_id, thread_ts) tuples.
|
||||
The 16-digit timestamp is converted to Slack ts format (with dot).
|
||||
"""
|
||||
results = []
|
||||
for match in SLACK_URL_PATTERN.finditer(query_text):
|
||||
channel_id = match.group(1)
|
||||
raw_ts = match.group(2)
|
||||
# Convert p1775491616524769 -> 1775491616.524769
|
||||
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
|
||||
results.append((channel_id, thread_ts))
|
||||
return results
|
||||
|
||||
|
||||
def build_slack_queries(
|
||||
query: ChunkIndexRequest,
|
||||
llm: LLM,
|
||||
entities: dict[str, Any] | None = None,
|
||||
available_channels: list[str] | None = None,
|
||||
) -> list[str | DirectThreadFetch]:
|
||||
) -> list[str]:
|
||||
"""Build Slack query strings with date filtering and query expansion."""
|
||||
default_search_days = 30
|
||||
if entities:
|
||||
@@ -695,15 +668,6 @@ def build_slack_queries(
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
|
||||
|
||||
# Check for Slack message URLs — if found, add direct fetch requests
|
||||
url_fetches: list[DirectThreadFetch] = []
|
||||
slack_urls = extract_slack_message_urls(query.query)
|
||||
for channel_id, thread_ts in slack_urls:
|
||||
url_fetches.append(
|
||||
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
|
||||
)
|
||||
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
|
||||
|
||||
# ALWAYS extract channel references from the query (not just for recency queries)
|
||||
channel_references = extract_channel_references_from_query(query.query)
|
||||
|
||||
@@ -720,9 +684,7 @@ def build_slack_queries(
|
||||
|
||||
# If valid channels detected, use ONLY those channels with NO keywords
|
||||
# Return query with ONLY time filter + channel filter (no keywords)
|
||||
return url_fetches + [
|
||||
build_channel_override_query(channel_references, time_filter)
|
||||
]
|
||||
return [build_channel_override_query(channel_references, time_filter)]
|
||||
except ValueError as e:
|
||||
# If validation fails, log the error and continue with normal flow
|
||||
logger.warning(f"Channel reference validation failed: {e}")
|
||||
@@ -740,8 +702,7 @@ def build_slack_queries(
|
||||
rephrased_queries = expand_query_with_llm(query.query, llm)
|
||||
|
||||
# Build final query strings with time filters
|
||||
search_queries = [
|
||||
return [
|
||||
rephrased_query.strip() + time_filter
|
||||
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
|
||||
]
|
||||
return url_fetches + search_queries
|
||||
|
||||
@@ -96,32 +96,6 @@ def _truncate_description(description: str | None, max_length: int = 500) -> str
|
||||
return description[: max_length - 3] + "..."
|
||||
|
||||
|
||||
# TODO: Replace mask-comparison approach with an explicit Unset sentinel from the
|
||||
# frontend indicating whether each credential field was actually modified. The current
|
||||
# approach is brittle (e.g. short credentials produce a fixed-length mask that could
|
||||
# collide) and mutates request values, which is surprising. The frontend should signal
|
||||
# "unchanged" vs "new value" directly rather than relying on masked-string equality.
|
||||
def _restore_masked_oauth_credentials(
|
||||
request_client_id: str | None,
|
||||
request_client_secret: str | None,
|
||||
existing_client: OAuthClientInformationFull,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""If the frontend sent back masked credentials, restore the real stored values."""
|
||||
if (
|
||||
request_client_id
|
||||
and existing_client.client_id
|
||||
and request_client_id == mask_string(existing_client.client_id)
|
||||
):
|
||||
request_client_id = existing_client.client_id
|
||||
if (
|
||||
request_client_secret
|
||||
and existing_client.client_secret
|
||||
and request_client_secret == mask_string(existing_client.client_secret)
|
||||
):
|
||||
request_client_secret = existing_client.client_secret
|
||||
return request_client_id, request_client_secret
|
||||
|
||||
|
||||
router = APIRouter(prefix="/mcp")
|
||||
admin_router = APIRouter(prefix="/admin/mcp")
|
||||
STATE_TTL_SECONDS = 60 * 5 # 5 minutes
|
||||
@@ -418,26 +392,6 @@ async def _connect_oauth(
|
||||
detail=f"Server was configured with authentication type {auth_type_str}",
|
||||
)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so we don't overwrite them with masks.
|
||||
if mcp_server.admin_connection_config:
|
||||
existing_data = extract_connection_data(
|
||||
mcp_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
existing_client_raw = existing_data.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
if existing_client_raw:
|
||||
existing_client = OAuthClientInformationFull.model_validate(
|
||||
existing_client_raw
|
||||
)
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
existing_client,
|
||||
)
|
||||
|
||||
# Create admin config with client info if provided
|
||||
config_data = MCPConnectionData(headers={})
|
||||
if request.oauth_client_id and request.oauth_client_secret:
|
||||
@@ -1402,19 +1356,6 @@ def _upsert_mcp_server(
|
||||
if client_info_raw:
|
||||
client_info = OAuthClientInformationFull.model_validate(client_info_raw)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so the comparison below sees no change
|
||||
# and the credentials aren't overwritten with masked strings.
|
||||
if client_info and request.auth_type == MCPAuthenticationType.OAUTH:
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
client_info,
|
||||
)
|
||||
|
||||
changing_connection_config = (
|
||||
not mcp_server.admin_connection_config
|
||||
or (
|
||||
|
||||
@@ -47,6 +47,8 @@ from onyx.llm.factory import get_llm
|
||||
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.utils import get_bedrock_token_limit
|
||||
from onyx.llm.utils import get_llm_contextual_cost
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.llm.utils import litellm_thinks_model_supports_image_input
|
||||
from onyx.llm.utils import test_llm
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
@@ -62,6 +64,8 @@ from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import BifrostFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BifrostModelsRequest
|
||||
from onyx.server.manage.llm.models import CustomProviderModelResponse
|
||||
from onyx.server.manage.llm.models import CustomProviderModelsRequest
|
||||
from onyx.server.manage.llm.models import CustomProviderOption
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
@@ -111,43 +115,6 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _resolve_api_key(
|
||||
api_key: str | None,
|
||||
provider_name: str | None,
|
||||
api_base: str | None,
|
||||
db_session: Session,
|
||||
) -> str | None:
|
||||
"""Return the real API key for model-fetch endpoints.
|
||||
|
||||
When editing an existing provider the form value is masked (e.g.
|
||||
``sk-a****b1c2``). If *provider_name* is supplied we can look up
|
||||
the unmasked key from the database so the external request succeeds.
|
||||
|
||||
The stored key is only returned when the request's *api_base*
|
||||
matches the value stored in the database.
|
||||
"""
|
||||
if not provider_name:
|
||||
return api_key
|
||||
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.api_key:
|
||||
# Normalise both URLs before comparing so trailing-slash
|
||||
# differences don't cause a false mismatch.
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
request_base = (api_base or "").strip().rstrip("/")
|
||||
if stored_base != request_base:
|
||||
return api_key
|
||||
|
||||
stored_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
# Only resolve when the incoming value is the masked form of the
|
||||
# stored key — i.e. the user hasn't typed a new key.
|
||||
if api_key and api_key == _mask_string(stored_key):
|
||||
return stored_key
|
||||
return api_key
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
@@ -313,6 +280,158 @@ def fetch_custom_provider_names(
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/custom/available-models")
|
||||
def fetch_custom_provider_models(
|
||||
request: CustomProviderModelsRequest,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> list[CustomProviderModelResponse]:
|
||||
"""Fetch models for a custom provider.
|
||||
|
||||
When ``api_base`` is provided the endpoint hits the provider's
|
||||
OpenAI-compatible ``/v1/models`` (or ``/{api_version}/models``) to
|
||||
discover live models. Otherwise it falls back to the static list
|
||||
that LiteLLM ships for the given provider slug.
|
||||
|
||||
In both cases the response is enriched with metadata from LiteLLM
|
||||
(display name, max input tokens, vision support) when available.
|
||||
"""
|
||||
if request.api_base:
|
||||
return _fetch_custom_models_from_api(
|
||||
provider=request.provider,
|
||||
api_base=request.api_base,
|
||||
api_key=request.api_key,
|
||||
api_version=request.api_version,
|
||||
)
|
||||
|
||||
return _fetch_custom_models_from_litellm(request.provider)
|
||||
|
||||
|
||||
def _enrich_custom_model(
|
||||
name: str,
|
||||
provider: str,
|
||||
*,
|
||||
api_display_name: str | None = None,
|
||||
api_max_input_tokens: int | None = None,
|
||||
api_supports_image_input: bool | None = None,
|
||||
) -> CustomProviderModelResponse:
|
||||
"""Build a ``CustomProviderModelResponse`` enriched with LiteLLM metadata.
|
||||
|
||||
Values explicitly provided by the source API take precedence; LiteLLM
|
||||
metadata is used as a fallback.
|
||||
"""
|
||||
from onyx.llm.model_name_parser import parse_litellm_model_name
|
||||
|
||||
# LiteLLM keys are typically "provider/model"
|
||||
litellm_key = f"{provider}/{name}" if not name.startswith(f"{provider}/") else name
|
||||
parsed = parse_litellm_model_name(litellm_key)
|
||||
|
||||
# display_name: prefer API-provided name, then LiteLLM enrichment, then raw name
|
||||
if api_display_name and api_display_name != name:
|
||||
display_name = api_display_name
|
||||
else:
|
||||
display_name = parsed.display_name or name
|
||||
|
||||
# max_input_tokens: prefer API value, then LiteLLM lookup
|
||||
if api_max_input_tokens is not None:
|
||||
max_input_tokens: int | None = api_max_input_tokens
|
||||
else:
|
||||
try:
|
||||
max_input_tokens = get_max_input_tokens(name, provider)
|
||||
except Exception:
|
||||
max_input_tokens = None
|
||||
|
||||
# supports_image_input: prefer API value, then LiteLLM inference
|
||||
if api_supports_image_input is not None:
|
||||
supports_image = api_supports_image_input
|
||||
else:
|
||||
supports_image = litellm_thinks_model_supports_image_input(name, provider)
|
||||
|
||||
return CustomProviderModelResponse(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
max_input_tokens=max_input_tokens,
|
||||
supports_image_input=supports_image,
|
||||
)
|
||||
|
||||
|
||||
def _fetch_custom_models_from_api(
|
||||
provider: str,
|
||||
api_base: str,
|
||||
api_key: str | None,
|
||||
api_version: str | None,
|
||||
) -> list[CustomProviderModelResponse]:
|
||||
"""Hit an OpenAI-compatible ``/v1/models`` (or versioned variant)."""
|
||||
cleaned = api_base.strip().rstrip("/")
|
||||
if api_version:
|
||||
url = f"{cleaned}/{api_version.strip().strip('/')}/models"
|
||||
elif cleaned.endswith("/v1"):
|
||||
url = f"{cleaned}/models"
|
||||
else:
|
||||
url = f"{cleaned}/v1/models"
|
||||
|
||||
response_json = _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="Custom provider",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from the provider's API.",
|
||||
)
|
||||
|
||||
results: list[CustomProviderModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_id = model.get("id", "")
|
||||
if not model_id:
|
||||
continue
|
||||
if is_embedding_model(model_id):
|
||||
continue
|
||||
results.append(
|
||||
_enrich_custom_model(
|
||||
model_id,
|
||||
provider,
|
||||
api_display_name=model.get("name"),
|
||||
api_max_input_tokens=model.get("context_length"),
|
||||
api_supports_image_input=infer_vision_support(model_id),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse custom provider model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from the provider's API.",
|
||||
)
|
||||
|
||||
return sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
|
||||
def _fetch_custom_models_from_litellm(
|
||||
provider: str,
|
||||
) -> list[CustomProviderModelResponse]:
|
||||
"""Fall back to litellm's static ``models_by_provider`` mapping."""
|
||||
import litellm
|
||||
|
||||
model_names = litellm.models_by_provider.get(provider)
|
||||
if model_names is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Unknown provider: {provider}",
|
||||
)
|
||||
return sorted(
|
||||
(_enrich_custom_model(name, provider) for name in model_names),
|
||||
key=lambda m: m.name.lower(),
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/built-in/options")
|
||||
def fetch_llm_options(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
@@ -1211,17 +1330,16 @@ def get_ollama_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str | None) -> dict:
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
"""Perform GET to OpenRouter /models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/models"
|
||||
headers: dict[str, str] = {
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
# Optional headers recommended by OpenRouter for attribution
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
@@ -1244,12 +1362,8 @@ def get_openrouter_available_models(
|
||||
Parses id, name (display), context_length, and architecture.input_modalities.
|
||||
"""
|
||||
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openrouter_models_response(
|
||||
api_base=request.api_base, api_key=api_key
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
data = response_json.get("data", [])
|
||||
@@ -1342,18 +1456,13 @@ def get_lm_studio_available_models(
|
||||
|
||||
# If provider_name is given and the api_key hasn't been changed by the user,
|
||||
# fall back to the stored API key from the database (the form value is masked).
|
||||
# Only do so when the api_base matches what is stored.
|
||||
api_key = request.api_key
|
||||
if request.provider_name and not request.api_key_changed:
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=request.provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.custom_config:
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
if stored_base == cleaned_api_base:
|
||||
api_key = existing_provider.custom_config.get(
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
)
|
||||
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
|
||||
|
||||
url = f"{cleaned_api_base}/api/v1/models"
|
||||
headers: dict[str, str] = {}
|
||||
@@ -1437,12 +1546,8 @@ def get_litellm_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=api_key, api_base=request.api_base
|
||||
api_key=request.api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1499,7 +1604,7 @@ def get_litellm_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str | None, api_base: str) -> dict:
|
||||
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
@@ -1574,12 +1679,8 @@ def get_bifrost_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BifrostFinalModelResponse]:
|
||||
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_bifrost_models_response(
|
||||
api_base=request.api_base, api_key=api_key
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1668,12 +1769,8 @@ def get_openai_compatible_server_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OpenAICompatibleFinalModelResponse]:
|
||||
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openai_compatible_server_response(
|
||||
api_base=request.api_base, api_key=api_key
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
|
||||
@@ -477,6 +477,21 @@ class BifrostFinalModelResponse(BaseModel):
|
||||
supports_reasoning: bool
|
||||
|
||||
|
||||
# Custom provider dynamic models fetch
|
||||
class CustomProviderModelsRequest(BaseModel):
|
||||
provider: str # LiteLLM provider slug (e.g. "deepseek", "fireworks_ai")
|
||||
api_base: str | None = None # If set, fetches live models via /v1/models
|
||||
api_key: str | None = None
|
||||
api_version: str | None = None # If set, used to construct the models URL
|
||||
|
||||
|
||||
class CustomProviderModelResponse(BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
# OpenAI Compatible dynamic models fetch
|
||||
class OpenAICompatibleModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
|
||||
10
backend/pyproject.toml
Normal file
10
backend/pyproject.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[project]
|
||||
name = "onyx-backend"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"onyx[backend,dev,ee]",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
onyx = { workspace = true }
|
||||
@@ -46,11 +46,11 @@ curl -LsSf https://astral.py/uv/install.sh | sh
|
||||
|
||||
1. Edit `pyproject.toml`
|
||||
2. Add/update/remove dependencies in the appropriate section:
|
||||
- `[dependency-groups]` for dev tools
|
||||
- `[project.dependencies]` for **shared** dependencies (used by both backend and model_server)
|
||||
- `[dependency-groups.backend]` for backend-only dependencies
|
||||
- `[dependency-groups.dev]` for dev tools
|
||||
- `[dependency-groups.ee]` for EE features
|
||||
- `[dependency-groups.model_server]` for model_server-only dependencies (ML packages)
|
||||
- `[project.optional-dependencies.backend]` for backend-only dependencies
|
||||
- `[project.optional-dependencies.model_server]` for model_server-only dependencies (ML packages)
|
||||
- `[project.optional-dependencies.ee]` for EE features
|
||||
3. Commit your changes - pre-commit hooks will automatically regenerate the lock file and requirements
|
||||
|
||||
### 3. Generating Lock File and Requirements
|
||||
@@ -64,10 +64,10 @@ To manually regenerate:
|
||||
|
||||
```bash
|
||||
uv lock
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --group backend -o backend/requirements/default.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --extra backend -o backend/requirements/default.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --group dev -o backend/requirements/dev.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --group ee -o backend/requirements/ee.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --group model_server -o backend/requirements/model_server.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --extra ee -o backend/requirements/ee.txt
|
||||
uv export --no-emit-project --no-default-groups --no-hashes --extra model_server -o backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
### 4. Installing Dependencies
|
||||
@@ -76,14 +76,30 @@ If enabled, all packages are installed automatically by the `uv-sync` pre-commit
|
||||
branches or pulling new changes.
|
||||
|
||||
```bash
|
||||
# For development (most common) — installs shared + backend + dev + ee
|
||||
uv sync
|
||||
# For everything (most common)
|
||||
uv sync --all-extras
|
||||
|
||||
# For backend production only (shared + backend dependencies)
|
||||
uv sync --no-default-groups --group backend
|
||||
# For backend production (shared + backend dependencies)
|
||||
uv sync --extra backend
|
||||
|
||||
# For backend development (shared + backend + dev tools)
|
||||
uv sync --extra backend --extra dev
|
||||
|
||||
# For backend with EE (shared + backend + ee)
|
||||
uv sync --extra backend --extra ee
|
||||
|
||||
# For model server (shared + model_server, NO backend deps!)
|
||||
uv sync --no-default-groups --group model_server
|
||||
uv sync --extra model_server
|
||||
```
|
||||
|
||||
`uv` aggressively [ignores active virtual environments](https://docs.astral.sh/uv/concepts/projects/config/#project-environment-path) and prefers the root virtual environment.
|
||||
When working in workspace packages, be sure to pass `--active` when syncing the virtual environment:
|
||||
|
||||
```bash
|
||||
cd backend/
|
||||
source .venv/bin/activate
|
||||
uv sync --active
|
||||
uv run --active ...
|
||||
```
|
||||
|
||||
### 5. Upgrading Dependencies
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --group backend -o backend/requirements/default.txt
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --extra backend -o backend/requirements/default.txt
|
||||
agent-client-protocol==0.7.1
|
||||
# via onyx
|
||||
aioboto3==15.1.0
|
||||
@@ -19,6 +19,7 @@ aiohttp==3.13.4
|
||||
# aiobotocore
|
||||
# discord-py
|
||||
# litellm
|
||||
# onyx
|
||||
# voyageai
|
||||
aioitertools==0.13.0
|
||||
# via aiobotocore
|
||||
@@ -27,6 +28,7 @@ aiolimiter==1.2.1
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
alembic==1.10.4
|
||||
# via onyx
|
||||
amqp==5.3.1
|
||||
# via kombu
|
||||
annotated-doc==0.0.4
|
||||
@@ -49,10 +51,13 @@ argon2-cffi==23.1.0
|
||||
argon2-cffi-bindings==25.1.0
|
||||
# via argon2-cffi
|
||||
asana==5.0.8
|
||||
# via onyx
|
||||
async-timeout==5.0.1 ; python_full_version < '3.11.3'
|
||||
# via redis
|
||||
asyncpg==0.30.0
|
||||
# via onyx
|
||||
atlassian-python-api==3.41.16
|
||||
# via onyx
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
@@ -63,6 +68,7 @@ attrs==25.4.0
|
||||
authlib==1.6.9
|
||||
# via fastmcp
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
# via onyx
|
||||
babel==2.17.0
|
||||
# via courlan
|
||||
backoff==2.2.1
|
||||
@@ -80,6 +86,7 @@ beautifulsoup4==4.12.3
|
||||
# atlassian-python-api
|
||||
# markdownify
|
||||
# markitdown
|
||||
# onyx
|
||||
# unstructured
|
||||
billiard==4.2.3
|
||||
# via celery
|
||||
@@ -87,7 +94,9 @@ boto3==1.39.11
|
||||
# via
|
||||
# aiobotocore
|
||||
# cohere
|
||||
# onyx
|
||||
boto3-stubs==1.39.11
|
||||
# via onyx
|
||||
botocore==1.39.11
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -96,6 +105,7 @@ botocore==1.39.11
|
||||
botocore-stubs==1.40.74
|
||||
# via boto3-stubs
|
||||
braintrust==0.3.9
|
||||
# via onyx
|
||||
brotli==1.2.0
|
||||
# via onyx
|
||||
bytecode==0.17.0
|
||||
@@ -105,6 +115,7 @@ cachetools==6.2.2
|
||||
caio==0.9.25
|
||||
# via aiofile
|
||||
celery==5.5.1
|
||||
# via onyx
|
||||
certifi==2025.11.12
|
||||
# via
|
||||
# asana
|
||||
@@ -123,6 +134,7 @@ cffi==2.0.0
|
||||
# pynacl
|
||||
# zstandard
|
||||
chardet==5.2.0
|
||||
# via onyx
|
||||
charset-normalizer==3.4.4
|
||||
# via
|
||||
# htmldate
|
||||
@@ -134,6 +146,7 @@ charset-normalizer==3.4.4
|
||||
chevron==0.14.0
|
||||
# via braintrust
|
||||
chonkie==1.0.10
|
||||
# via onyx
|
||||
claude-agent-sdk==0.1.19
|
||||
# via onyx
|
||||
click==8.3.1
|
||||
@@ -188,12 +201,15 @@ cryptography==46.0.6
|
||||
cyclopts==4.2.4
|
||||
# via fastmcp
|
||||
dask==2026.1.1
|
||||
# via distributed
|
||||
# via
|
||||
# distributed
|
||||
# onyx
|
||||
dataclasses-json==0.6.7
|
||||
# via unstructured
|
||||
dateparser==1.2.2
|
||||
# via htmldate
|
||||
ddtrace==3.10.0
|
||||
# via onyx
|
||||
decorator==5.2.1
|
||||
# via retry
|
||||
defusedxml==0.7.1
|
||||
@@ -207,6 +223,7 @@ deprecated==1.3.1
|
||||
discord-py==2.4.0
|
||||
# via onyx
|
||||
distributed==2026.1.1
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via
|
||||
# openai
|
||||
@@ -218,6 +235,7 @@ docstring-parser==0.17.0
|
||||
docutils==0.22.3
|
||||
# via rich-rst
|
||||
dropbox==12.0.2
|
||||
# via onyx
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
email-validator==2.2.0
|
||||
@@ -233,6 +251,7 @@ et-xmlfile==2.0.0
|
||||
events==0.5
|
||||
# via opensearch-py
|
||||
exa-py==1.15.4
|
||||
# via onyx
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# braintrust
|
||||
@@ -243,16 +262,23 @@ fastapi==0.133.1
|
||||
# fastapi-users
|
||||
# onyx
|
||||
fastapi-limiter==0.1.6
|
||||
# via onyx
|
||||
fastapi-users==15.0.4
|
||||
# via fastapi-users-db-sqlalchemy
|
||||
# via
|
||||
# fastapi-users-db-sqlalchemy
|
||||
# onyx
|
||||
fastapi-users-db-sqlalchemy==7.0.0
|
||||
# via onyx
|
||||
fastavro==1.12.1
|
||||
# via cohere
|
||||
fastmcp==3.2.0
|
||||
# via onyx
|
||||
fastuuid==0.14.0
|
||||
# via litellm
|
||||
filelock==3.20.3
|
||||
# via huggingface-hub
|
||||
# via
|
||||
# huggingface-hub
|
||||
# onyx
|
||||
filetype==1.2.0
|
||||
# via unstructured
|
||||
flatbuffers==25.9.23
|
||||
@@ -272,6 +298,7 @@ gitpython==3.1.45
|
||||
google-api-core==2.28.1
|
||||
# via google-api-python-client
|
||||
google-api-python-client==2.86.0
|
||||
# via onyx
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-api-core
|
||||
@@ -281,8 +308,11 @@ google-auth==2.48.0
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-auth-httplib2==0.1.0
|
||||
# via google-api-python-client
|
||||
# via
|
||||
# google-api-python-client
|
||||
# onyx
|
||||
google-auth-oauthlib==1.0.0
|
||||
# via onyx
|
||||
google-genai==1.52.0
|
||||
# via onyx
|
||||
googleapis-common-protos==1.72.0
|
||||
@@ -310,6 +340,7 @@ htmldate==1.9.1
|
||||
httpcore==1.0.9
|
||||
# via
|
||||
# httpx
|
||||
# onyx
|
||||
# unstructured-client
|
||||
httplib2==0.31.0
|
||||
# via
|
||||
@@ -326,16 +357,21 @@ httpx==0.28.1
|
||||
# langsmith
|
||||
# litellm
|
||||
# mcp
|
||||
# onyx
|
||||
# openai
|
||||
# unstructured-client
|
||||
httpx-oauth==0.15.1
|
||||
# via onyx
|
||||
httpx-sse==0.4.3
|
||||
# via
|
||||
# cohere
|
||||
# mcp
|
||||
hubspot-api-client==11.1.0
|
||||
# via onyx
|
||||
huggingface-hub==0.35.3
|
||||
# via tokenizers
|
||||
# via
|
||||
# onyx
|
||||
# tokenizers
|
||||
humanfriendly==10.0
|
||||
# via coloredlogs
|
||||
hyperframe==6.1.0
|
||||
@@ -354,7 +390,9 @@ importlib-metadata==8.7.0
|
||||
# litellm
|
||||
# opentelemetry-api
|
||||
inflection==0.5.1
|
||||
# via pyairtable
|
||||
# via
|
||||
# onyx
|
||||
# pyairtable
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
isodate==0.7.2
|
||||
@@ -376,6 +414,7 @@ jinja2==3.1.6
|
||||
# distributed
|
||||
# litellm
|
||||
jira==3.10.5
|
||||
# via onyx
|
||||
jiter==0.12.0
|
||||
# via openai
|
||||
jmespath==1.0.1
|
||||
@@ -391,7 +430,9 @@ jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
jsonref==1.1.0
|
||||
# via fastmcp
|
||||
# via
|
||||
# fastmcp
|
||||
# onyx
|
||||
jsonschema==4.25.1
|
||||
# via
|
||||
# litellm
|
||||
@@ -409,12 +450,15 @@ kombu==5.5.4
|
||||
kubernetes==31.0.0
|
||||
# via onyx
|
||||
langchain-core==1.2.22
|
||||
# via onyx
|
||||
langdetect==1.0.9
|
||||
# via unstructured
|
||||
langfuse==3.10.0
|
||||
# via onyx
|
||||
langsmith==0.3.45
|
||||
# via langchain-core
|
||||
lazy-imports==1.0.1
|
||||
# via onyx
|
||||
legacy-cgi==2.6.4 ; python_full_version >= '3.13'
|
||||
# via ddtrace
|
||||
litellm==1.81.6
|
||||
@@ -429,6 +473,7 @@ lxml==5.3.0
|
||||
# justext
|
||||
# lxml-html-clean
|
||||
# markitdown
|
||||
# onyx
|
||||
# python-docx
|
||||
# python-pptx
|
||||
# python3-saml
|
||||
@@ -443,7 +488,9 @@ magika==0.6.3
|
||||
makefun==1.16.0
|
||||
# via fastapi-users
|
||||
mako==1.2.4
|
||||
# via alembic
|
||||
# via
|
||||
# alembic
|
||||
# onyx
|
||||
mammoth==1.11.0
|
||||
# via markitdown
|
||||
markdown-it-py==4.0.0
|
||||
@@ -451,6 +498,7 @@ markdown-it-py==4.0.0
|
||||
markdownify==1.2.2
|
||||
# via markitdown
|
||||
markitdown==0.1.2
|
||||
# via onyx
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# jinja2
|
||||
@@ -464,9 +512,11 @@ mcp==1.26.0
|
||||
# via
|
||||
# claude-agent-sdk
|
||||
# fastmcp
|
||||
# onyx
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistune==3.2.0
|
||||
# via onyx
|
||||
more-itertools==10.8.0
|
||||
# via
|
||||
# jaraco-classes
|
||||
@@ -475,10 +525,13 @@ more-itertools==10.8.0
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
msal==1.34.0
|
||||
# via office365-rest-python-client
|
||||
# via
|
||||
# office365-rest-python-client
|
||||
# onyx
|
||||
msgpack==1.1.2
|
||||
# via distributed
|
||||
msoffcrypto-tool==5.4.2
|
||||
# via onyx
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -495,6 +548,7 @@ mypy-extensions==1.0.0
|
||||
# mypy
|
||||
# typing-inspect
|
||||
nest-asyncio==1.6.0
|
||||
# via onyx
|
||||
nltk==3.9.4
|
||||
# via unstructured
|
||||
numpy==2.4.1
|
||||
@@ -509,8 +563,10 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# atlassian-python-api
|
||||
# kubernetes
|
||||
# onyx
|
||||
# requests-oauthlib
|
||||
office365-rest-python-client==2.6.2
|
||||
# via onyx
|
||||
olefile==0.47
|
||||
# via
|
||||
# msoffcrypto-tool
|
||||
@@ -526,11 +582,15 @@ openai==2.14.0
|
||||
openapi-pydantic==0.5.1
|
||||
# via fastmcp
|
||||
openinference-instrumentation==0.1.42
|
||||
# via onyx
|
||||
openinference-semantic-conventions==0.1.25
|
||||
# via openinference-instrumentation
|
||||
openpyxl==3.0.10
|
||||
# via markitdown
|
||||
# via
|
||||
# markitdown
|
||||
# onyx
|
||||
opensearch-py==3.0.0
|
||||
# via onyx
|
||||
opentelemetry-api==1.39.1
|
||||
# via
|
||||
# ddtrace
|
||||
@@ -546,6 +606,7 @@ opentelemetry-exporter-otlp-proto-http==1.39.1
|
||||
# via langfuse
|
||||
opentelemetry-proto==1.39.1
|
||||
# via
|
||||
# onyx
|
||||
# opentelemetry-exporter-otlp-proto-common
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
opentelemetry-sdk==1.39.1
|
||||
@@ -579,6 +640,7 @@ parameterized==0.9.0
|
||||
partd==1.4.2
|
||||
# via dask
|
||||
passlib==1.7.4
|
||||
# via onyx
|
||||
pathable==0.4.4
|
||||
# via jsonschema-path
|
||||
pdfminer-six==20251107
|
||||
@@ -590,7 +652,9 @@ platformdirs==4.5.0
|
||||
# fastmcp
|
||||
# zeep
|
||||
playwright==1.55.0
|
||||
# via pytest-playwright
|
||||
# via
|
||||
# onyx
|
||||
# pytest-playwright
|
||||
pluggy==1.6.0
|
||||
# via pytest
|
||||
ply==3.11
|
||||
@@ -620,9 +684,12 @@ protobuf==6.33.5
|
||||
psutil==7.1.3
|
||||
# via
|
||||
# distributed
|
||||
# onyx
|
||||
# unstructured
|
||||
psycopg2-binary==2.9.9
|
||||
# via onyx
|
||||
puremagic==1.28
|
||||
# via onyx
|
||||
pwdlib==0.3.0
|
||||
# via fastapi-users
|
||||
py==1.11.0
|
||||
@@ -630,6 +697,7 @@ py==1.11.0
|
||||
py-key-value-aio==0.4.4
|
||||
# via fastmcp
|
||||
pyairtable==3.0.1
|
||||
# via onyx
|
||||
pyasn1==0.6.3
|
||||
# via
|
||||
# pyasn1-modules
|
||||
@@ -639,6 +707,7 @@ pyasn1-modules==0.4.2
|
||||
pycparser==2.23 ; implementation_name != 'PyPy'
|
||||
# via cffi
|
||||
pycryptodome==3.19.1
|
||||
# via onyx
|
||||
pydantic==2.11.7
|
||||
# via
|
||||
# agent-client-protocol
|
||||
@@ -665,6 +734,7 @@ pydantic-settings==2.12.0
|
||||
pyee==13.0.0
|
||||
# via playwright
|
||||
pygithub==2.5.0
|
||||
# via onyx
|
||||
pygments==2.20.0
|
||||
# via rich
|
||||
pyjwt==2.12.0
|
||||
@@ -675,13 +745,17 @@ pyjwt==2.12.0
|
||||
# pygithub
|
||||
# simple-salesforce
|
||||
pympler==1.1
|
||||
# via onyx
|
||||
pynacl==1.6.2
|
||||
# via pygithub
|
||||
pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.9.2
|
||||
# via unstructured-client
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
pyperclip==1.11.0
|
||||
# via fastmcp
|
||||
pyreadline3==3.5.4 ; sys_platform == 'win32'
|
||||
@@ -694,7 +768,9 @@ pytest==8.3.5
|
||||
pytest-base-url==2.1.0
|
||||
# via pytest-playwright
|
||||
pytest-mock==3.12.0
|
||||
# via onyx
|
||||
pytest-playwright==0.7.0
|
||||
# via onyx
|
||||
python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -705,9 +781,11 @@ python-dateutil==2.8.2
|
||||
# htmldate
|
||||
# hubspot-api-client
|
||||
# kubernetes
|
||||
# onyx
|
||||
# opensearch-py
|
||||
# pandas
|
||||
python-docx==1.1.2
|
||||
# via onyx
|
||||
python-dotenv==1.1.1
|
||||
# via
|
||||
# braintrust
|
||||
@@ -715,8 +793,10 @@ python-dotenv==1.1.1
|
||||
# litellm
|
||||
# magika
|
||||
# mcp
|
||||
# onyx
|
||||
# pydantic-settings
|
||||
python-gitlab==5.6.0
|
||||
# via onyx
|
||||
python-http-client==3.3.7
|
||||
# via sendgrid
|
||||
python-iso639==2025.11.16
|
||||
@@ -727,15 +807,19 @@ python-multipart==0.0.22
|
||||
# via
|
||||
# fastapi-users
|
||||
# mcp
|
||||
# onyx
|
||||
python-oxmsg==0.0.2
|
||||
# via unstructured
|
||||
python-pptx==0.6.23
|
||||
# via markitdown
|
||||
# via
|
||||
# markitdown
|
||||
# onyx
|
||||
python-slugify==8.0.4
|
||||
# via
|
||||
# braintrust
|
||||
# pytest-playwright
|
||||
python3-saml==1.15.0
|
||||
# via onyx
|
||||
pytz==2025.2
|
||||
# via
|
||||
# dateparser
|
||||
@@ -743,6 +827,7 @@ pytz==2025.2
|
||||
# pandas
|
||||
# zeep
|
||||
pywikibot==9.0.0
|
||||
# via onyx
|
||||
pywin32==311 ; sys_platform == 'win32'
|
||||
# via
|
||||
# mcp
|
||||
@@ -759,9 +844,13 @@ pyyaml==6.0.3
|
||||
# kubernetes
|
||||
# langchain-core
|
||||
rapidfuzz==3.13.0
|
||||
# via unstructured
|
||||
# via
|
||||
# onyx
|
||||
# unstructured
|
||||
redis==5.0.8
|
||||
# via fastapi-limiter
|
||||
# via
|
||||
# fastapi-limiter
|
||||
# onyx
|
||||
referencing==0.36.2
|
||||
# via
|
||||
# jsonschema
|
||||
@@ -792,6 +881,7 @@ requests==2.33.0
|
||||
# matrix-client
|
||||
# msal
|
||||
# office365-rest-python-client
|
||||
# onyx
|
||||
# opensearch-py
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
# pyairtable
|
||||
@@ -817,6 +907,7 @@ requests-oauthlib==1.3.1
|
||||
# google-auth-oauthlib
|
||||
# jira
|
||||
# kubernetes
|
||||
# onyx
|
||||
requests-toolbelt==1.0.0
|
||||
# via
|
||||
# jira
|
||||
@@ -827,6 +918,7 @@ requests-toolbelt==1.0.0
|
||||
retry==0.9.2
|
||||
# via onyx
|
||||
rfc3986==1.5.0
|
||||
# via onyx
|
||||
rich==14.2.0
|
||||
# via
|
||||
# cyclopts
|
||||
@@ -846,12 +938,15 @@ s3transfer==0.13.1
|
||||
secretstorage==3.5.0 ; sys_platform == 'linux'
|
||||
# via keyring
|
||||
sendgrid==6.12.5
|
||||
# via onyx
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
shapely==2.0.6
|
||||
# via onyx
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
simple-salesforce==1.12.6
|
||||
# via onyx
|
||||
six==1.17.0
|
||||
# via
|
||||
# asana
|
||||
@@ -866,6 +961,7 @@ six==1.17.0
|
||||
# python-dateutil
|
||||
# stone
|
||||
slack-sdk==3.20.2
|
||||
# via onyx
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
@@ -880,6 +976,7 @@ sqlalchemy==2.0.15
|
||||
# via
|
||||
# alembic
|
||||
# fastapi-users-db-sqlalchemy
|
||||
# onyx
|
||||
sse-starlette==3.0.3
|
||||
# via mcp
|
||||
sseclient-py==1.8.0
|
||||
@@ -888,11 +985,14 @@ starlette==0.49.3
|
||||
# via
|
||||
# fastapi
|
||||
# mcp
|
||||
# onyx
|
||||
# prometheus-fastapi-instrumentator
|
||||
stone==3.3.1
|
||||
# via dropbox
|
||||
stripe==10.12.0
|
||||
# via onyx
|
||||
supervisor==4.3.0
|
||||
# via onyx
|
||||
sympy==1.14.0
|
||||
# via onnxruntime
|
||||
tblib==3.2.2
|
||||
@@ -905,8 +1005,11 @@ tenacity==9.1.2
|
||||
text-unidecode==1.3
|
||||
# via python-slugify
|
||||
tiktoken==0.7.0
|
||||
# via litellm
|
||||
# via
|
||||
# litellm
|
||||
# onyx
|
||||
timeago==1.0.16
|
||||
# via onyx
|
||||
tld==0.13.1
|
||||
# via courlan
|
||||
tokenizers==0.21.4
|
||||
@@ -930,11 +1033,13 @@ tqdm==4.67.1
|
||||
# openai
|
||||
# unstructured
|
||||
trafilatura==1.12.2
|
||||
# via onyx
|
||||
typer==0.20.0
|
||||
# via mcp
|
||||
types-awscrt==0.28.4
|
||||
# via botocore-stubs
|
||||
types-openpyxl==3.0.4.7
|
||||
# via onyx
|
||||
types-requests==2.32.0.20250328
|
||||
# via cohere
|
||||
types-s3transfer==0.14.0
|
||||
@@ -1000,8 +1105,11 @@ tzlocal==5.3.1
|
||||
uncalled-for==0.2.0
|
||||
# via fastmcp
|
||||
unstructured==0.18.27
|
||||
# via onyx
|
||||
unstructured-client==0.42.6
|
||||
# via unstructured
|
||||
# via
|
||||
# onyx
|
||||
# unstructured
|
||||
uritemplate==4.2.0
|
||||
# via google-api-python-client
|
||||
urllib3==2.6.3
|
||||
@@ -1013,6 +1121,7 @@ urllib3==2.6.3
|
||||
# htmldate
|
||||
# hubspot-api-client
|
||||
# kubernetes
|
||||
# onyx
|
||||
# opensearch-py
|
||||
# pyairtable
|
||||
# pygithub
|
||||
@@ -1062,7 +1171,9 @@ xlrd==2.0.2
|
||||
xlsxwriter==3.2.9
|
||||
# via python-pptx
|
||||
xmlsec==1.3.14
|
||||
# via python3-saml
|
||||
# via
|
||||
# onyx
|
||||
# python3-saml
|
||||
xmltodict==1.0.2
|
||||
# via ddtrace
|
||||
yarl==1.22.0
|
||||
@@ -1076,3 +1187,4 @@ zipp==3.23.0
|
||||
zstandard==0.23.0
|
||||
# via langsmith
|
||||
zulip==0.8.2
|
||||
# via onyx
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --group dev -o backend/requirements/dev.txt
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --extra dev -o backend/requirements/dev.txt
|
||||
agent-client-protocol==0.7.1
|
||||
# via onyx
|
||||
aioboto3==15.1.0
|
||||
@@ -47,6 +47,7 @@ attrs==25.4.0
|
||||
# jsonschema
|
||||
# referencing
|
||||
black==25.1.0
|
||||
# via onyx
|
||||
boto3==1.39.11
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -59,6 +60,7 @@ botocore==1.39.11
|
||||
brotli==1.2.0
|
||||
# via onyx
|
||||
celery-types==0.19.0
|
||||
# via onyx
|
||||
certifi==2025.11.12
|
||||
# via
|
||||
# httpcore
|
||||
@@ -120,6 +122,7 @@ execnet==2.1.2
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
faker==40.1.2
|
||||
# via onyx
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
@@ -153,6 +156,7 @@ h11==0.16.0
|
||||
# httpcore
|
||||
# uvicorn
|
||||
hatchling==1.28.0
|
||||
# via onyx
|
||||
hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||
# via huggingface-hub
|
||||
httpcore==1.0.9
|
||||
@@ -183,6 +187,7 @@ importlib-metadata==8.7.0
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
ipykernel==6.29.5
|
||||
# via onyx
|
||||
ipython==9.7.0
|
||||
# via ipykernel
|
||||
ipython-pygments-lexers==1.1.1
|
||||
@@ -219,11 +224,13 @@ litellm==1.81.6
|
||||
mako==1.2.4
|
||||
# via alembic
|
||||
manygo==0.2.0
|
||||
# via onyx
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# jinja2
|
||||
# mako
|
||||
matplotlib==3.10.8
|
||||
# via onyx
|
||||
matplotlib-inline==0.2.1
|
||||
# via
|
||||
# ipykernel
|
||||
@@ -236,10 +243,12 @@ multidict==6.7.0
|
||||
# aiohttp
|
||||
# yarl
|
||||
mypy==1.13.0
|
||||
# via onyx
|
||||
mypy-extensions==1.0.0
|
||||
# via
|
||||
# black
|
||||
# mypy
|
||||
# onyx
|
||||
nest-asyncio==1.6.0
|
||||
# via ipykernel
|
||||
nodeenv==1.9.1
|
||||
@@ -255,12 +264,15 @@ oauthlib==3.2.2
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.7.3
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
# litellm
|
||||
# onyx
|
||||
openapi-generator-cli==7.17.0
|
||||
# via onyx-devtools
|
||||
# via
|
||||
# onyx
|
||||
# onyx-devtools
|
||||
packaging==24.2
|
||||
# via
|
||||
# black
|
||||
@@ -270,6 +282,7 @@ packaging==24.2
|
||||
# matplotlib
|
||||
# pytest
|
||||
pandas-stubs==2.3.3.251201
|
||||
# via onyx
|
||||
parameterized==0.9.0
|
||||
# via cohere
|
||||
parso==0.8.5
|
||||
@@ -292,6 +305,7 @@ pluggy==1.6.0
|
||||
# hatchling
|
||||
# pytest
|
||||
pre-commit==3.2.2
|
||||
# via onyx
|
||||
prometheus-client==0.23.1
|
||||
# via
|
||||
# onyx
|
||||
@@ -345,16 +359,22 @@ pyparsing==3.2.5
|
||||
# via matplotlib
|
||||
pytest==8.3.5
|
||||
# via
|
||||
# onyx
|
||||
# pytest-alembic
|
||||
# pytest-asyncio
|
||||
# pytest-dotenv
|
||||
# pytest-repeat
|
||||
# pytest-xdist
|
||||
pytest-alembic==0.12.1
|
||||
# via onyx
|
||||
pytest-asyncio==1.3.0
|
||||
# via onyx
|
||||
pytest-dotenv==0.5.2
|
||||
# via onyx
|
||||
pytest-repeat==0.9.4
|
||||
# via onyx
|
||||
pytest-xdist==3.8.0
|
||||
# via onyx
|
||||
python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -387,7 +407,9 @@ referencing==0.36.2
|
||||
regex==2025.11.3
|
||||
# via tiktoken
|
||||
release-tag==0.5.2
|
||||
# via onyx
|
||||
reorder-python-imports-black==3.14.0
|
||||
# via onyx
|
||||
requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
@@ -408,6 +430,7 @@ rpds-py==0.29.0
|
||||
rsa==4.9.1
|
||||
# via google-auth
|
||||
ruff==0.12.0
|
||||
# via onyx
|
||||
s3transfer==0.13.1
|
||||
# via boto3
|
||||
sentry-sdk==2.14.0
|
||||
@@ -461,22 +484,39 @@ traitlets==5.14.3
|
||||
trove-classifiers==2025.12.1.14
|
||||
# via hatchling
|
||||
types-beautifulsoup4==4.12.0.3
|
||||
# via onyx
|
||||
types-html5lib==1.1.11.13
|
||||
# via types-beautifulsoup4
|
||||
# via
|
||||
# onyx
|
||||
# types-beautifulsoup4
|
||||
types-oauthlib==3.2.0.9
|
||||
# via onyx
|
||||
types-passlib==1.7.7.20240106
|
||||
# via onyx
|
||||
types-pillow==10.2.0.20240822
|
||||
# via onyx
|
||||
types-psutil==7.1.3.20251125
|
||||
# via onyx
|
||||
types-psycopg2==2.9.21.10
|
||||
# via onyx
|
||||
types-python-dateutil==2.8.19.13
|
||||
# via onyx
|
||||
types-pytz==2023.3.1.1
|
||||
# via pandas-stubs
|
||||
# via
|
||||
# onyx
|
||||
# pandas-stubs
|
||||
types-pyyaml==6.0.12.11
|
||||
# via onyx
|
||||
types-regex==2023.3.23.1
|
||||
# via onyx
|
||||
types-requests==2.32.0.20250328
|
||||
# via cohere
|
||||
# via
|
||||
# cohere
|
||||
# onyx
|
||||
types-retry==0.9.9.3
|
||||
# via onyx
|
||||
types-setuptools==68.0.0.3
|
||||
# via onyx
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
@@ -534,3 +574,4 @@ yarl==1.22.0
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
zizmor==1.18.0
|
||||
# via onyx
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --group ee -o backend/requirements/ee.txt
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --extra ee -o backend/requirements/ee.txt
|
||||
agent-client-protocol==0.7.1
|
||||
# via onyx
|
||||
aioboto3==15.1.0
|
||||
@@ -182,6 +182,7 @@ packaging==24.2
|
||||
parameterized==0.9.0
|
||||
# via cohere
|
||||
posthog==3.7.4
|
||||
# via onyx
|
||||
prometheus-client==0.23.1
|
||||
# via
|
||||
# onyx
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --group model_server -o backend/requirements/model_server.txt
|
||||
# uv export --no-emit-project --no-default-groups --no-hashes --extra model_server -o backend/requirements/model_server.txt
|
||||
accelerate==1.6.0
|
||||
# via onyx
|
||||
agent-client-protocol==0.7.1
|
||||
# via onyx
|
||||
aioboto3==15.1.0
|
||||
@@ -104,6 +105,7 @@ distro==1.9.0
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
einops==0.8.1
|
||||
# via onyx
|
||||
fastapi==0.133.1
|
||||
# via
|
||||
# onyx
|
||||
@@ -205,6 +207,7 @@ networkx==3.5
|
||||
numpy==2.4.1
|
||||
# via
|
||||
# accelerate
|
||||
# onyx
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# transformers
|
||||
@@ -360,6 +363,7 @@ s3transfer==0.13.1
|
||||
safetensors==0.5.3
|
||||
# via
|
||||
# accelerate
|
||||
# onyx
|
||||
# transformers
|
||||
scikit-learn==1.7.2
|
||||
# via sentence-transformers
|
||||
@@ -368,6 +372,7 @@ scipy==1.16.3
|
||||
# scikit-learn
|
||||
# sentence-transformers
|
||||
sentence-transformers==4.0.2
|
||||
# via onyx
|
||||
sentry-sdk==2.14.0
|
||||
# via onyx
|
||||
setuptools==80.9.0 ; python_full_version >= '3.12'
|
||||
@@ -406,6 +411,7 @@ tokenizers==0.21.4
|
||||
torch==2.9.1
|
||||
# via
|
||||
# accelerate
|
||||
# onyx
|
||||
# sentence-transformers
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
@@ -414,7 +420,9 @@ tqdm==4.67.1
|
||||
# sentence-transformers
|
||||
# transformers
|
||||
transformers==4.53.0
|
||||
# via sentence-transformers
|
||||
# via
|
||||
# onyx
|
||||
# sentence-transformers
|
||||
triton==3.5.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
types-requests==2.32.0.20250328
|
||||
|
||||
@@ -9,7 +9,6 @@ from unittest.mock import patch
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from ee.onyx.db.license import delete_license
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.db.license import get_used_seats
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
@@ -215,43 +214,3 @@ class TestCheckSeatAvailabilityMultiTenant:
|
||||
assert result.available is False
|
||||
assert result.error_message is not None
|
||||
mock_tenant_count.assert_called_once_with("tenant-abc")
|
||||
|
||||
|
||||
class TestGetUsedSeatsAccountTypeFiltering:
|
||||
"""Verify get_used_seats query excludes SERVICE_ACCOUNT but includes BOT."""
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", False)
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_excludes_service_accounts(self, mock_get_session: MagicMock) -> None:
|
||||
"""SERVICE_ACCOUNT users should not count toward seats."""
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session.execute.return_value.scalar.return_value = 5
|
||||
|
||||
result = get_used_seats()
|
||||
|
||||
assert result == 5
|
||||
# Inspect the compiled query to verify account_type filter
|
||||
call_args = mock_session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "SERVICE_ACCOUNT" in compiled
|
||||
# BOT should NOT be excluded
|
||||
assert "BOT" not in compiled
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", False)
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_still_excludes_ext_perm_user(self, mock_get_session: MagicMock) -> None:
|
||||
"""EXT_PERM_USER exclusion should still be present."""
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session.execute.return_value.scalar.return_value = 3
|
||||
|
||||
get_used_seats()
|
||||
|
||||
call_args = mock_session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "EXT_PERM_USER" in compiled
|
||||
|
||||
@@ -6,7 +6,6 @@ import requests
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.connectors.jira.connector import _JIRA_BULK_FETCH_LIMIT
|
||||
from onyx.connectors.jira.connector import bulk_fetch_issues
|
||||
|
||||
|
||||
@@ -146,29 +145,3 @@ def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
|
||||
|
||||
|
||||
def test_bulk_fetch_respects_api_batch_limit() -> None:
|
||||
"""Requests to the bulkfetch endpoint never exceed _JIRA_BULK_FETCH_LIMIT IDs."""
|
||||
client = _mock_jira_client()
|
||||
total_issues = _JIRA_BULK_FETCH_LIMIT * 3 + 7
|
||||
all_ids = [str(i) for i in range(total_issues)]
|
||||
|
||||
batch_sizes: list[int] = []
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
batch_sizes.append(len(ids))
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
result = bulk_fetch_issues(client, all_ids)
|
||||
|
||||
assert len(result) == total_issues
|
||||
# keeping this hardcoded because it's the documented limit
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
assert all(size <= 100 for size in batch_sizes)
|
||||
assert len(batch_sizes) == 4
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
"""Tests for _build_thread_text function."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.slack_search import _build_thread_text
|
||||
|
||||
|
||||
def _make_msg(user: str, text: str, ts: str) -> dict[str, str]:
|
||||
return {"user": user, "text": text, "ts": ts}
|
||||
|
||||
|
||||
class TestBuildThreadText:
|
||||
"""Verify _build_thread_text includes full thread replies up to cap."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_includes_all_replies(self, mock_profiles: MagicMock) -> None:
|
||||
"""All replies within cap are included in output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "parent msg", "1000.0"),
|
||||
_make_msg("U2", "reply 1", "1001.0"),
|
||||
_make_msg("U3", "reply 2", "1002.0"),
|
||||
_make_msg("U4", "reply 3", "1003.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "parent msg" in result
|
||||
assert "reply 1" in result
|
||||
assert "reply 2" in result
|
||||
assert "reply 3" in result
|
||||
assert "..." not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_non_thread_returns_parent_only(self, mock_profiles: MagicMock) -> None:
|
||||
"""Single message (no replies) returns just the parent text."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [_make_msg("U1", "just a message", "1000.0")]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "just a message" in result
|
||||
assert "Replies:" not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_parent_always_first(self, mock_profiles: MagicMock) -> None:
|
||||
"""Thread parent message is always the first line of output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "I am the parent", "1000.0"),
|
||||
_make_msg("U2", "I am a reply", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
parent_pos = result.index("I am the parent")
|
||||
reply_pos = result.index("I am a reply")
|
||||
assert parent_pos < reply_pos
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_user_profiles_resolved(self, mock_profiles: MagicMock) -> None:
|
||||
"""User IDs in thread text are replaced with display names."""
|
||||
mock_profiles.return_value = {"U1": "Alice", "U2": "Bob"}
|
||||
messages = [
|
||||
_make_msg("U1", "hello", "1000.0"),
|
||||
_make_msg("U2", "world", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "Alice" in result
|
||||
assert "Bob" in result
|
||||
assert "<@U1>" not in result
|
||||
assert "<@U2>" not in result
|
||||
@@ -1,108 +0,0 @@
|
||||
"""Tests for Slack URL parsing and direct thread fetch via URL override."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.slack_search import _fetch_thread_from_url
|
||||
from onyx.context.search.federated.slack_search_utils import extract_slack_message_urls
|
||||
|
||||
|
||||
class TestExtractSlackMessageUrls:
|
||||
"""Verify URL parsing extracts channel_id and timestamp correctly."""
|
||||
|
||||
def test_standard_url(self) -> None:
|
||||
query = "summarize https://mycompany.slack.com/archives/C097NBWMY8Y/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 1
|
||||
assert results[0] == ("C097NBWMY8Y", "1775491616.524769")
|
||||
|
||||
def test_multiple_urls(self) -> None:
|
||||
query = (
|
||||
"compare https://co.slack.com/archives/C111/p1234567890123456 "
|
||||
"and https://co.slack.com/archives/C222/p9876543210987654"
|
||||
)
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 2
|
||||
assert results[0] == ("C111", "1234567890.123456")
|
||||
assert results[1] == ("C222", "9876543210.987654")
|
||||
|
||||
def test_no_urls(self) -> None:
|
||||
query = "what happened in #general last week?"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_non_slack_url_ignored(self) -> None:
|
||||
query = "check https://google.com/archives/C111/p1234567890123456"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_timestamp_conversion(self) -> None:
|
||||
"""p prefix removed, dot inserted after 10th digit."""
|
||||
query = "https://x.slack.com/archives/CABC123/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
channel_id, ts = results[0]
|
||||
assert channel_id == "CABC123"
|
||||
assert ts == "1775491616.524769"
|
||||
assert not ts.startswith("p")
|
||||
assert "." in ts
|
||||
|
||||
|
||||
class TestFetchThreadFromUrl:
|
||||
"""Verify _fetch_thread_from_url calls conversations.replies and returns SlackMessage."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search._build_thread_text")
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_successful_fetch(
|
||||
self, mock_webclient_cls: MagicMock, mock_build_thread: MagicMock
|
||||
) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
|
||||
# Mock conversations_replies
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = [
|
||||
{"user": "U1", "text": "parent", "ts": "1775491616.524769"},
|
||||
{"user": "U2", "text": "reply 1", "ts": "1775491617.000000"},
|
||||
{"user": "U3", "text": "reply 2", "ts": "1775491618.000000"},
|
||||
]
|
||||
mock_client.conversations_replies.return_value = mock_response
|
||||
|
||||
# Mock channel info
|
||||
mock_ch_response = MagicMock()
|
||||
mock_ch_response.get.return_value = {"name": "general"}
|
||||
mock_client.conversations_info.return_value = mock_ch_response
|
||||
|
||||
mock_build_thread.return_value = (
|
||||
"U1: parent\n\nReplies:\n\nU2: reply 1\n\nU3: reply 2"
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(
|
||||
channel_id="C097NBWMY8Y", thread_ts="1775491616.524769"
|
||||
)
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
|
||||
assert len(result.messages) == 1
|
||||
msg = result.messages[0]
|
||||
assert msg.channel_id == "C097NBWMY8Y"
|
||||
assert msg.thread_id is None # Prevents double-enrichment
|
||||
assert msg.slack_score == 100000.0
|
||||
assert "parent" in msg.text
|
||||
mock_client.conversations_replies.assert_called_once_with(
|
||||
channel="C097NBWMY8Y", ts="1775491616.524769"
|
||||
)
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_api_error_returns_empty(self, mock_webclient_cls: MagicMock) -> None:
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
mock_client.conversations_replies.side_effect = SlackApiError(
|
||||
message="channel_not_found",
|
||||
response=MagicMock(status_code=404),
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(channel_id="CBAD", thread_ts="1234567890.123456")
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
assert len(result.messages) == 0
|
||||
@@ -505,7 +505,6 @@ class TestGetLMStudioAvailableModels:
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.api_base = "http://localhost:1234"
|
||||
mock_provider.custom_config = {"LM_STUDIO_API_KEY": "stored-secret"}
|
||||
|
||||
response = {
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -10,9 +9,7 @@ from uuid import uuid4
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.server.scim.api import _check_seat_availability
|
||||
from ee.onyx.server.scim.api import _scim_name_to_str
|
||||
from ee.onyx.server.scim.api import _seat_lock_id_for_tenant
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
@@ -744,80 +741,3 @@ class TestEmailCasePreservation:
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
|
||||
class TestSeatLock:
|
||||
"""Tests for the advisory lock in _check_seat_availability."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_abc")
|
||||
def test_acquires_advisory_lock_before_checking(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The advisory lock must be acquired before the seat check runs."""
|
||||
call_order: list[str] = []
|
||||
|
||||
def track_execute(stmt: Any, _params: Any = None) -> None:
|
||||
if "pg_advisory_xact_lock" in str(stmt):
|
||||
call_order.append("lock")
|
||||
|
||||
mock_dal.session.execute.side_effect = track_execute
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop"
|
||||
) as mock_fetch:
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_fn = MagicMock(return_value=mock_result)
|
||||
mock_fetch.return_value = mock_fn
|
||||
|
||||
def track_check(*_args: Any, **_kwargs: Any) -> Any:
|
||||
call_order.append("check")
|
||||
return mock_result
|
||||
|
||||
mock_fn.side_effect = track_check
|
||||
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
assert call_order == ["lock", "check"]
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_xyz")
|
||||
def test_lock_uses_tenant_scoped_key(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The lock id must be derived from the tenant via _seat_lock_id_for_tenant."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_check = MagicMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=mock_check,
|
||||
):
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
mock_dal.session.execute.assert_called_once()
|
||||
params = mock_dal.session.execute.call_args[0][1]
|
||||
assert params["lock_id"] == _seat_lock_id_for_tenant("tenant_xyz")
|
||||
|
||||
def test_seat_lock_id_is_stable_and_tenant_scoped(self) -> None:
|
||||
"""Lock id must be deterministic and differ across tenants."""
|
||||
assert _seat_lock_id_for_tenant("t1") == _seat_lock_id_for_tenant("t1")
|
||||
assert _seat_lock_id_for_tenant("t1") != _seat_lock_id_for_tenant("t2")
|
||||
|
||||
def test_no_lock_when_ee_absent(
|
||||
self,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""No advisory lock should be acquired when the EE check is absent."""
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=None,
|
||||
):
|
||||
result = _check_seat_availability(mock_dal)
|
||||
|
||||
assert result is None
|
||||
mock_dal.session.execute.assert_not_called()
|
||||
|
||||
@@ -70,10 +70,6 @@ spec:
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,sandbox",
|
||||
]
|
||||
ports:
|
||||
- name: metrics
|
||||
containerPort: 9094
|
||||
protocol: TCP
|
||||
resources:
|
||||
{{- toYaml .Values.celery_worker_heavy.resources | nindent 12 }}
|
||||
envFrom:
|
||||
|
||||
@@ -28,7 +28,7 @@ dependencies = [
|
||||
"kubernetes>=31.0.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
[project.optional-dependencies]
|
||||
# Main backend application dependencies
|
||||
backend = [
|
||||
"aiohttp==3.13.4",
|
||||
@@ -195,9 +195,6 @@ model_server = [
|
||||
"sentry-sdk[fastapi,celery,starlette]==2.14.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
default-groups = ["backend", "dev", "ee", "model_server"]
|
||||
|
||||
[tool.mypy]
|
||||
plugins = "sqlalchemy.ext.mypy.plugin"
|
||||
mypy_path = "backend"
|
||||
@@ -233,7 +230,7 @@ follow_imports = "skip"
|
||||
ignore_errors = true
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["tools/ods"]
|
||||
members = ["backend", "tools/ods"]
|
||||
|
||||
[tool.basedpyright]
|
||||
include = ["backend"]
|
||||
|
||||
310
uv.lock
generated
310
uv.lock
generated
@@ -14,6 +14,12 @@ resolution-markers = [
|
||||
"python_full_version < '3.12' and sys_platform != 'win32'",
|
||||
]
|
||||
|
||||
[manifest]
|
||||
members = [
|
||||
"onyx",
|
||||
"onyx-backend",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "accelerate"
|
||||
version = "1.6.0"
|
||||
@@ -4228,7 +4234,7 @@ dependencies = [
|
||||
{ name = "voyageai" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
[package.optional-dependencies]
|
||||
backend = [
|
||||
{ name = "aiohttp" },
|
||||
{ name = "alembic" },
|
||||
@@ -4382,175 +4388,179 @@ model-server = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "accelerate", marker = "extra == 'model-server'", specifier = "==1.6.0" },
|
||||
{ name = "agent-client-protocol", specifier = ">=0.7.1" },
|
||||
{ name = "aioboto3", specifier = "==15.1.0" },
|
||||
{ name = "aiohttp", marker = "extra == 'backend'", specifier = "==3.13.4" },
|
||||
{ name = "alembic", marker = "extra == 'backend'", specifier = "==1.10.4" },
|
||||
{ name = "asana", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "asyncpg", marker = "extra == 'backend'", specifier = "==0.30.0" },
|
||||
{ name = "atlassian-python-api", marker = "extra == 'backend'", specifier = "==3.41.16" },
|
||||
{ name = "azure-cognitiveservices-speech", marker = "extra == 'backend'", specifier = "==1.38.0" },
|
||||
{ name = "beautifulsoup4", marker = "extra == 'backend'", specifier = "==4.12.3" },
|
||||
{ name = "black", marker = "extra == 'dev'", specifier = "==25.1.0" },
|
||||
{ name = "boto3", marker = "extra == 'backend'", specifier = "==1.39.11" },
|
||||
{ name = "boto3-stubs", extras = ["s3"], marker = "extra == 'backend'", specifier = "==1.39.11" },
|
||||
{ name = "braintrust", marker = "extra == 'backend'", specifier = "==0.3.9" },
|
||||
{ name = "brotli", specifier = ">=1.2.0" },
|
||||
{ name = "celery", marker = "extra == 'backend'", specifier = "==5.5.1" },
|
||||
{ name = "celery-types", marker = "extra == 'dev'", specifier = "==0.19.0" },
|
||||
{ name = "chardet", marker = "extra == 'backend'", specifier = "==5.2.0" },
|
||||
{ name = "chonkie", marker = "extra == 'backend'", specifier = "==1.0.10" },
|
||||
{ name = "claude-agent-sdk", specifier = ">=0.1.19" },
|
||||
{ name = "cohere", specifier = "==5.6.1" },
|
||||
{ name = "dask", marker = "extra == 'backend'", specifier = "==2026.1.1" },
|
||||
{ name = "ddtrace", marker = "extra == 'backend'", specifier = "==3.10.0" },
|
||||
{ name = "discord-py", specifier = "==2.4.0" },
|
||||
{ name = "discord-py", marker = "extra == 'backend'", specifier = "==2.4.0" },
|
||||
{ name = "distributed", marker = "extra == 'backend'", specifier = "==2026.1.1" },
|
||||
{ name = "dropbox", marker = "extra == 'backend'", specifier = "==12.0.2" },
|
||||
{ name = "einops", marker = "extra == 'model-server'", specifier = "==0.8.1" },
|
||||
{ name = "exa-py", marker = "extra == 'backend'", specifier = "==1.15.4" },
|
||||
{ name = "faker", marker = "extra == 'dev'", specifier = "==40.1.2" },
|
||||
{ name = "fastapi", specifier = "==0.133.1" },
|
||||
{ name = "fastapi-limiter", marker = "extra == 'backend'", specifier = "==0.1.6" },
|
||||
{ name = "fastapi-users", marker = "extra == 'backend'", specifier = "==15.0.4" },
|
||||
{ name = "fastapi-users-db-sqlalchemy", marker = "extra == 'backend'", specifier = "==7.0.0" },
|
||||
{ name = "fastmcp", marker = "extra == 'backend'", specifier = "==3.2.0" },
|
||||
{ name = "filelock", marker = "extra == 'backend'", specifier = "==3.20.3" },
|
||||
{ name = "google-api-python-client", marker = "extra == 'backend'", specifier = "==2.86.0" },
|
||||
{ name = "google-auth-httplib2", marker = "extra == 'backend'", specifier = "==0.1.0" },
|
||||
{ name = "google-auth-oauthlib", marker = "extra == 'backend'", specifier = "==1.0.0" },
|
||||
{ name = "google-genai", specifier = "==1.52.0" },
|
||||
{ name = "hatchling", marker = "extra == 'dev'", specifier = "==1.28.0" },
|
||||
{ name = "httpcore", marker = "extra == 'backend'", specifier = "==1.0.9" },
|
||||
{ name = "httpx", extras = ["http2"], marker = "extra == 'backend'", specifier = "==0.28.1" },
|
||||
{ name = "httpx-oauth", marker = "extra == 'backend'", specifier = "==0.15.1" },
|
||||
{ name = "hubspot-api-client", marker = "extra == 'backend'", specifier = "==11.1.0" },
|
||||
{ name = "huggingface-hub", marker = "extra == 'backend'", specifier = "==0.35.3" },
|
||||
{ name = "inflection", marker = "extra == 'backend'", specifier = "==0.5.1" },
|
||||
{ name = "ipykernel", marker = "extra == 'dev'", specifier = "==6.29.5" },
|
||||
{ name = "jira", marker = "extra == 'backend'", specifier = "==3.10.5" },
|
||||
{ name = "jsonref", marker = "extra == 'backend'", specifier = "==1.1.0" },
|
||||
{ name = "kubernetes", specifier = ">=31.0.0" },
|
||||
{ name = "kubernetes", marker = "extra == 'backend'", specifier = "==31.0.0" },
|
||||
{ name = "langchain-core", marker = "extra == 'backend'", specifier = "==1.2.22" },
|
||||
{ name = "langfuse", marker = "extra == 'backend'", specifier = "==3.10.0" },
|
||||
{ name = "lazy-imports", marker = "extra == 'backend'", specifier = "==1.0.1" },
|
||||
{ name = "litellm", specifier = "==1.81.6" },
|
||||
{ name = "lxml", marker = "extra == 'backend'", specifier = "==5.3.0" },
|
||||
{ name = "mako", marker = "extra == 'backend'", specifier = "==1.2.4" },
|
||||
{ name = "manygo", marker = "extra == 'dev'", specifier = "==0.2.0" },
|
||||
{ name = "markitdown", extras = ["pdf", "docx", "pptx", "xlsx", "xls"], marker = "extra == 'backend'", specifier = "==0.1.2" },
|
||||
{ name = "matplotlib", marker = "extra == 'dev'", specifier = "==3.10.8" },
|
||||
{ name = "mcp", extras = ["cli"], marker = "extra == 'backend'", specifier = "==1.26.0" },
|
||||
{ name = "mistune", marker = "extra == 'backend'", specifier = "==3.2.0" },
|
||||
{ name = "msal", marker = "extra == 'backend'", specifier = "==1.34.0" },
|
||||
{ name = "msoffcrypto-tool", marker = "extra == 'backend'", specifier = "==5.4.2" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = "==1.13.0" },
|
||||
{ name = "mypy-extensions", marker = "extra == 'dev'", specifier = "==1.0.0" },
|
||||
{ name = "nest-asyncio", marker = "extra == 'backend'", specifier = "==1.6.0" },
|
||||
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
|
||||
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.6.2" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.7.3" },
|
||||
{ name = "openai", specifier = "==2.14.0" },
|
||||
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
|
||||
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
|
||||
{ name = "openpyxl", marker = "extra == 'backend'", specifier = "==3.0.10" },
|
||||
{ name = "opensearch-py", marker = "extra == 'backend'", specifier = "==3.0.0" },
|
||||
{ name = "opentelemetry-proto", marker = "extra == 'backend'", specifier = ">=1.39.0" },
|
||||
{ name = "pandas-stubs", marker = "extra == 'dev'", specifier = "~=2.3.3" },
|
||||
{ name = "passlib", marker = "extra == 'backend'", specifier = "==1.7.4" },
|
||||
{ name = "playwright", marker = "extra == 'backend'", specifier = "==1.55.0" },
|
||||
{ name = "posthog", marker = "extra == 'ee'", specifier = "==3.7.4" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = "==3.2.2" },
|
||||
{ name = "prometheus-client", specifier = ">=0.21.1" },
|
||||
{ name = "prometheus-fastapi-instrumentator", specifier = "==7.1.0" },
|
||||
{ name = "psutil", marker = "extra == 'backend'", specifier = "==7.1.3" },
|
||||
{ name = "psycopg2-binary", marker = "extra == 'backend'", specifier = "==2.9.9" },
|
||||
{ name = "puremagic", marker = "extra == 'backend'", specifier = "==1.28" },
|
||||
{ name = "pyairtable", marker = "extra == 'backend'", specifier = "==3.0.1" },
|
||||
{ name = "pycryptodome", marker = "extra == 'backend'", specifier = "==3.19.1" },
|
||||
{ name = "pydantic", specifier = "==2.11.7" },
|
||||
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
|
||||
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
|
||||
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.9.2" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
|
||||
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
|
||||
{ name = "pytest-dotenv", marker = "extra == 'dev'", specifier = "==0.5.2" },
|
||||
{ name = "pytest-mock", marker = "extra == 'backend'", specifier = "==3.12.0" },
|
||||
{ name = "pytest-playwright", marker = "extra == 'backend'", specifier = "==0.7.0" },
|
||||
{ name = "pytest-repeat", marker = "extra == 'dev'", specifier = "==0.9.4" },
|
||||
{ name = "pytest-xdist", marker = "extra == 'dev'", specifier = "==3.8.0" },
|
||||
{ name = "python-dateutil", marker = "extra == 'backend'", specifier = "==2.8.2" },
|
||||
{ name = "python-docx", marker = "extra == 'backend'", specifier = "==1.1.2" },
|
||||
{ name = "python-dotenv", marker = "extra == 'backend'", specifier = "==1.1.1" },
|
||||
{ name = "python-gitlab", marker = "extra == 'backend'", specifier = "==5.6.0" },
|
||||
{ name = "python-multipart", marker = "extra == 'backend'", specifier = "==0.0.22" },
|
||||
{ name = "python-pptx", marker = "extra == 'backend'", specifier = "==0.6.23" },
|
||||
{ name = "python3-saml", marker = "extra == 'backend'", specifier = "==1.15.0" },
|
||||
{ name = "pywikibot", marker = "extra == 'backend'", specifier = "==9.0.0" },
|
||||
{ name = "rapidfuzz", marker = "extra == 'backend'", specifier = "==3.13.0" },
|
||||
{ name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
|
||||
{ name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
|
||||
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.33.0" },
|
||||
{ name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
|
||||
{ name = "retry", specifier = "==0.9.2" },
|
||||
{ name = "rfc3986", marker = "extra == 'backend'", specifier = "==1.5.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = "==0.12.0" },
|
||||
{ name = "safetensors", marker = "extra == 'model-server'", specifier = "==0.5.3" },
|
||||
{ name = "sendgrid", marker = "extra == 'backend'", specifier = "==6.12.5" },
|
||||
{ name = "sentence-transformers", marker = "extra == 'model-server'", specifier = "==4.0.2" },
|
||||
{ name = "sentry-sdk", specifier = "==2.14.0" },
|
||||
{ name = "sentry-sdk", extras = ["fastapi", "celery", "starlette"], marker = "extra == 'model-server'", specifier = "==2.14.0" },
|
||||
{ name = "shapely", marker = "extra == 'backend'", specifier = "==2.0.6" },
|
||||
{ name = "simple-salesforce", marker = "extra == 'backend'", specifier = "==1.12.6" },
|
||||
{ name = "slack-sdk", marker = "extra == 'backend'", specifier = "==3.20.2" },
|
||||
{ name = "sqlalchemy", extras = ["mypy"], marker = "extra == 'backend'", specifier = "==2.0.15" },
|
||||
{ name = "starlette", marker = "extra == 'backend'", specifier = "==0.49.3" },
|
||||
{ name = "stripe", marker = "extra == 'backend'", specifier = "==10.12.0" },
|
||||
{ name = "supervisor", marker = "extra == 'backend'", specifier = "==4.3.0" },
|
||||
{ name = "tiktoken", marker = "extra == 'backend'", specifier = "==0.7.0" },
|
||||
{ name = "timeago", marker = "extra == 'backend'", specifier = "==1.0.16" },
|
||||
{ name = "torch", marker = "extra == 'model-server'", specifier = "==2.9.1" },
|
||||
{ name = "trafilatura", marker = "extra == 'backend'", specifier = "==1.12.2" },
|
||||
{ name = "transformers", marker = "extra == 'model-server'", specifier = "==4.53.0" },
|
||||
{ name = "types-beautifulsoup4", marker = "extra == 'dev'", specifier = "==4.12.0.3" },
|
||||
{ name = "types-html5lib", marker = "extra == 'dev'", specifier = "==1.1.11.13" },
|
||||
{ name = "types-oauthlib", marker = "extra == 'dev'", specifier = "==3.2.0.9" },
|
||||
{ name = "types-openpyxl", marker = "extra == 'backend'", specifier = "==3.0.4.7" },
|
||||
{ name = "types-passlib", marker = "extra == 'dev'", specifier = "==1.7.7.20240106" },
|
||||
{ name = "types-pillow", marker = "extra == 'dev'", specifier = "==10.2.0.20240822" },
|
||||
{ name = "types-psutil", marker = "extra == 'dev'", specifier = "==7.1.3.20251125" },
|
||||
{ name = "types-psycopg2", marker = "extra == 'dev'", specifier = "==2.9.21.10" },
|
||||
{ name = "types-python-dateutil", marker = "extra == 'dev'", specifier = "==2.8.19.13" },
|
||||
{ name = "types-pytz", marker = "extra == 'dev'", specifier = "==2023.3.1.1" },
|
||||
{ name = "types-pyyaml", marker = "extra == 'dev'", specifier = "==6.0.12.11" },
|
||||
{ name = "types-regex", marker = "extra == 'dev'", specifier = "==2023.3.23.1" },
|
||||
{ name = "types-requests", marker = "extra == 'dev'", specifier = "==2.32.0.20250328" },
|
||||
{ name = "types-retry", marker = "extra == 'dev'", specifier = "==0.9.9.3" },
|
||||
{ name = "types-setuptools", marker = "extra == 'dev'", specifier = "==68.0.0.3" },
|
||||
{ name = "unstructured", marker = "extra == 'backend'", specifier = "==0.18.27" },
|
||||
{ name = "unstructured-client", marker = "extra == 'backend'", specifier = "==0.42.6" },
|
||||
{ name = "urllib3", marker = "extra == 'backend'", specifier = "==2.6.3" },
|
||||
{ name = "uvicorn", specifier = "==0.35.0" },
|
||||
{ name = "voyageai", specifier = "==0.2.3" },
|
||||
{ name = "xmlsec", marker = "extra == 'backend'", specifier = "==1.3.14" },
|
||||
{ name = "zizmor", marker = "extra == 'dev'", specifier = "==1.18.0" },
|
||||
{ name = "zulip", marker = "extra == 'backend'", specifier = "==0.8.2" },
|
||||
]
|
||||
provides-extras = ["backend", "dev", "ee", "model-server"]
|
||||
|
||||
[[package]]
|
||||
name = "onyx-backend"
|
||||
version = "0.0.0"
|
||||
source = { virtual = "backend" }
|
||||
dependencies = [
|
||||
{ name = "onyx", extra = ["backend", "dev", "ee"] },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
backend = [
|
||||
{ name = "aiohttp", specifier = "==3.13.4" },
|
||||
{ name = "alembic", specifier = "==1.10.4" },
|
||||
{ name = "asana", specifier = "==5.0.8" },
|
||||
{ name = "asyncpg", specifier = "==0.30.0" },
|
||||
{ name = "atlassian-python-api", specifier = "==3.41.16" },
|
||||
{ name = "azure-cognitiveservices-speech", specifier = "==1.38.0" },
|
||||
{ name = "beautifulsoup4", specifier = "==4.12.3" },
|
||||
{ name = "boto3", specifier = "==1.39.11" },
|
||||
{ name = "boto3-stubs", extras = ["s3"], specifier = "==1.39.11" },
|
||||
{ name = "braintrust", specifier = "==0.3.9" },
|
||||
{ name = "celery", specifier = "==5.5.1" },
|
||||
{ name = "chardet", specifier = "==5.2.0" },
|
||||
{ name = "chonkie", specifier = "==1.0.10" },
|
||||
{ name = "dask", specifier = "==2026.1.1" },
|
||||
{ name = "ddtrace", specifier = "==3.10.0" },
|
||||
{ name = "discord-py", specifier = "==2.4.0" },
|
||||
{ name = "distributed", specifier = "==2026.1.1" },
|
||||
{ name = "dropbox", specifier = "==12.0.2" },
|
||||
{ name = "exa-py", specifier = "==1.15.4" },
|
||||
{ name = "fastapi-limiter", specifier = "==0.1.6" },
|
||||
{ name = "fastapi-users", specifier = "==15.0.4" },
|
||||
{ name = "fastapi-users-db-sqlalchemy", specifier = "==7.0.0" },
|
||||
{ name = "fastmcp", specifier = "==3.2.0" },
|
||||
{ name = "filelock", specifier = "==3.20.3" },
|
||||
{ name = "google-api-python-client", specifier = "==2.86.0" },
|
||||
{ name = "google-auth-httplib2", specifier = "==0.1.0" },
|
||||
{ name = "google-auth-oauthlib", specifier = "==1.0.0" },
|
||||
{ name = "httpcore", specifier = "==1.0.9" },
|
||||
{ name = "httpx", extras = ["http2"], specifier = "==0.28.1" },
|
||||
{ name = "httpx-oauth", specifier = "==0.15.1" },
|
||||
{ name = "hubspot-api-client", specifier = "==11.1.0" },
|
||||
{ name = "huggingface-hub", specifier = "==0.35.3" },
|
||||
{ name = "inflection", specifier = "==0.5.1" },
|
||||
{ name = "jira", specifier = "==3.10.5" },
|
||||
{ name = "jsonref", specifier = "==1.1.0" },
|
||||
{ name = "kubernetes", specifier = "==31.0.0" },
|
||||
{ name = "langchain-core", specifier = "==1.2.22" },
|
||||
{ name = "langfuse", specifier = "==3.10.0" },
|
||||
{ name = "lazy-imports", specifier = "==1.0.1" },
|
||||
{ name = "lxml", specifier = "==5.3.0" },
|
||||
{ name = "mako", specifier = "==1.2.4" },
|
||||
{ name = "markitdown", extras = ["pdf", "docx", "pptx", "xlsx", "xls"], specifier = "==0.1.2" },
|
||||
{ name = "mcp", extras = ["cli"], specifier = "==1.26.0" },
|
||||
{ name = "mistune", specifier = "==3.2.0" },
|
||||
{ name = "msal", specifier = "==1.34.0" },
|
||||
{ name = "msoffcrypto-tool", specifier = "==5.4.2" },
|
||||
{ name = "nest-asyncio", specifier = "==1.6.0" },
|
||||
{ name = "oauthlib", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", specifier = "==2.6.2" },
|
||||
{ name = "openinference-instrumentation", specifier = "==0.1.42" },
|
||||
{ name = "openpyxl", specifier = "==3.0.10" },
|
||||
{ name = "opensearch-py", specifier = "==3.0.0" },
|
||||
{ name = "opentelemetry-proto", specifier = ">=1.39.0" },
|
||||
{ name = "passlib", specifier = "==1.7.4" },
|
||||
{ name = "playwright", specifier = "==1.55.0" },
|
||||
{ name = "psutil", specifier = "==7.1.3" },
|
||||
{ name = "psycopg2-binary", specifier = "==2.9.9" },
|
||||
{ name = "puremagic", specifier = "==1.28" },
|
||||
{ name = "pyairtable", specifier = "==3.0.1" },
|
||||
{ name = "pycryptodome", specifier = "==3.19.1" },
|
||||
{ name = "pygithub", specifier = "==2.5.0" },
|
||||
{ name = "pympler", specifier = "==1.1" },
|
||||
{ name = "pypandoc-binary", specifier = "==1.16.2" },
|
||||
{ name = "pypdf", specifier = "==6.9.2" },
|
||||
{ name = "pytest-mock", specifier = "==3.12.0" },
|
||||
{ name = "pytest-playwright", specifier = "==0.7.0" },
|
||||
{ name = "python-dateutil", specifier = "==2.8.2" },
|
||||
{ name = "python-docx", specifier = "==1.1.2" },
|
||||
{ name = "python-dotenv", specifier = "==1.1.1" },
|
||||
{ name = "python-gitlab", specifier = "==5.6.0" },
|
||||
{ name = "python-multipart", specifier = "==0.0.22" },
|
||||
{ name = "python-pptx", specifier = "==0.6.23" },
|
||||
{ name = "python3-saml", specifier = "==1.15.0" },
|
||||
{ name = "pywikibot", specifier = "==9.0.0" },
|
||||
{ name = "rapidfuzz", specifier = "==3.13.0" },
|
||||
{ name = "redis", specifier = "==5.0.8" },
|
||||
{ name = "requests", specifier = "==2.33.0" },
|
||||
{ name = "requests-oauthlib", specifier = "==1.3.1" },
|
||||
{ name = "rfc3986", specifier = "==1.5.0" },
|
||||
{ name = "sendgrid", specifier = "==6.12.5" },
|
||||
{ name = "shapely", specifier = "==2.0.6" },
|
||||
{ name = "simple-salesforce", specifier = "==1.12.6" },
|
||||
{ name = "slack-sdk", specifier = "==3.20.2" },
|
||||
{ name = "sqlalchemy", extras = ["mypy"], specifier = "==2.0.15" },
|
||||
{ name = "starlette", specifier = "==0.49.3" },
|
||||
{ name = "stripe", specifier = "==10.12.0" },
|
||||
{ name = "supervisor", specifier = "==4.3.0" },
|
||||
{ name = "tiktoken", specifier = "==0.7.0" },
|
||||
{ name = "timeago", specifier = "==1.0.16" },
|
||||
{ name = "trafilatura", specifier = "==1.12.2" },
|
||||
{ name = "types-openpyxl", specifier = "==3.0.4.7" },
|
||||
{ name = "unstructured", specifier = "==0.18.27" },
|
||||
{ name = "unstructured-client", specifier = "==0.42.6" },
|
||||
{ name = "urllib3", specifier = "==2.6.3" },
|
||||
{ name = "xmlsec", specifier = "==1.3.14" },
|
||||
{ name = "zulip", specifier = "==0.8.2" },
|
||||
]
|
||||
dev = [
|
||||
{ name = "black", specifier = "==25.1.0" },
|
||||
{ name = "celery-types", specifier = "==0.19.0" },
|
||||
{ name = "faker", specifier = "==40.1.2" },
|
||||
{ name = "hatchling", specifier = "==1.28.0" },
|
||||
{ name = "ipykernel", specifier = "==6.29.5" },
|
||||
{ name = "manygo", specifier = "==0.2.0" },
|
||||
{ name = "matplotlib", specifier = "==3.10.8" },
|
||||
{ name = "mypy", specifier = "==1.13.0" },
|
||||
{ name = "mypy-extensions", specifier = "==1.0.0" },
|
||||
{ name = "onyx-devtools", specifier = "==0.7.3" },
|
||||
{ name = "openapi-generator-cli", specifier = "==7.17.0" },
|
||||
{ name = "pandas-stubs", specifier = "~=2.3.3" },
|
||||
{ name = "pre-commit", specifier = "==3.2.2" },
|
||||
{ name = "pytest", specifier = "==8.3.5" },
|
||||
{ name = "pytest-alembic", specifier = "==0.12.1" },
|
||||
{ name = "pytest-asyncio", specifier = "==1.3.0" },
|
||||
{ name = "pytest-dotenv", specifier = "==0.5.2" },
|
||||
{ name = "pytest-repeat", specifier = "==0.9.4" },
|
||||
{ name = "pytest-xdist", specifier = "==3.8.0" },
|
||||
{ name = "release-tag", specifier = "==0.5.2" },
|
||||
{ name = "reorder-python-imports-black", specifier = "==3.14.0" },
|
||||
{ name = "ruff", specifier = "==0.12.0" },
|
||||
{ name = "types-beautifulsoup4", specifier = "==4.12.0.3" },
|
||||
{ name = "types-html5lib", specifier = "==1.1.11.13" },
|
||||
{ name = "types-oauthlib", specifier = "==3.2.0.9" },
|
||||
{ name = "types-passlib", specifier = "==1.7.7.20240106" },
|
||||
{ name = "types-pillow", specifier = "==10.2.0.20240822" },
|
||||
{ name = "types-psutil", specifier = "==7.1.3.20251125" },
|
||||
{ name = "types-psycopg2", specifier = "==2.9.21.10" },
|
||||
{ name = "types-python-dateutil", specifier = "==2.8.19.13" },
|
||||
{ name = "types-pytz", specifier = "==2023.3.1.1" },
|
||||
{ name = "types-pyyaml", specifier = "==6.0.12.11" },
|
||||
{ name = "types-regex", specifier = "==2023.3.23.1" },
|
||||
{ name = "types-requests", specifier = "==2.32.0.20250328" },
|
||||
{ name = "types-retry", specifier = "==0.9.9.3" },
|
||||
{ name = "types-setuptools", specifier = "==68.0.0.3" },
|
||||
{ name = "zizmor", specifier = "==1.18.0" },
|
||||
]
|
||||
ee = [{ name = "posthog", specifier = "==3.7.4" }]
|
||||
model-server = [
|
||||
{ name = "accelerate", specifier = "==1.6.0" },
|
||||
{ name = "einops", specifier = "==0.8.1" },
|
||||
{ name = "numpy", specifier = "==2.4.1" },
|
||||
{ name = "safetensors", specifier = "==0.5.3" },
|
||||
{ name = "sentence-transformers", specifier = "==4.0.2" },
|
||||
{ name = "sentry-sdk", extras = ["fastapi", "celery", "starlette"], specifier = "==2.14.0" },
|
||||
{ name = "torch", specifier = "==2.9.1" },
|
||||
{ name = "transformers", specifier = "==4.53.0" },
|
||||
]
|
||||
[package.metadata]
|
||||
requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable = "." }]
|
||||
|
||||
[[package]]
|
||||
name = "onyx-devtools"
|
||||
|
||||
16
web/package-lock.json
generated
16
web/package-lock.json
generated
@@ -47,7 +47,6 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.0.0",
|
||||
"cookies-next": "^5.1.0",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^3.6.0",
|
||||
"docx-preview": "^0.3.7",
|
||||
"favicon-fetch": "^1.0.0",
|
||||
@@ -8844,15 +8843,6 @@
|
||||
"react": ">= 16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/copy-to-clipboard": {
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/copy-to-clipboard/-/copy-to-clipboard-3.3.3.tgz",
|
||||
"integrity": "sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"toggle-selection": "^1.0.6"
|
||||
}
|
||||
},
|
||||
"node_modules/core-js": {
|
||||
"version": "3.46.0",
|
||||
"hasInstallScript": true,
|
||||
@@ -17436,12 +17426,6 @@
|
||||
"node": ">=8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/toggle-selection": {
|
||||
"version": "1.0.6",
|
||||
"resolved": "https://registry.npmjs.org/toggle-selection/-/toggle-selection-1.0.6.tgz",
|
||||
"integrity": "sha512-BiZS+C1OS8g/q2RRbJmy59xpyghNBqrr6k5L/uKBGRsTfxmu3ffiRnd8mlGPUVayg8pvfi5urfnu8TU7DVOkLQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/toposort": {
|
||||
"version": "2.0.2",
|
||||
"license": "MIT"
|
||||
|
||||
@@ -65,7 +65,6 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.0.0",
|
||||
"cookies-next": "^5.1.0",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^3.6.0",
|
||||
"docx-preview": "^0.3.7",
|
||||
"favicon-fetch": "^1.0.0",
|
||||
|
||||
@@ -17,7 +17,6 @@ import DocumentSetCard from "@/sections/cards/DocumentSetCard";
|
||||
import CollapsibleSection from "@/app/admin/agents/CollapsibleSection";
|
||||
import { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
import { StandardAnswerCategoryDropdownField } from "@/components/standardAnswers/StandardAnswerCategoryDropdown";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import { RadioGroup } from "@/components/ui/radio-group";
|
||||
import { RadioGroupItemField } from "@/components/ui/RadioGroupItemField";
|
||||
import { AlertCircle } from "lucide-react";
|
||||
@@ -127,24 +126,6 @@ export function SlackChannelConfigFormFields({
|
||||
return documentSets.filter((ds) => !documentSetContainsSync(ds));
|
||||
}, [documentSets]);
|
||||
|
||||
const searchAgentOptions = useMemo(
|
||||
() =>
|
||||
availableAgents.map((persona) => ({
|
||||
label: persona.name,
|
||||
value: String(persona.id),
|
||||
})),
|
||||
[availableAgents]
|
||||
);
|
||||
|
||||
const nonSearchAgentOptions = useMemo(
|
||||
() =>
|
||||
nonSearchAgents.map((persona) => ({
|
||||
label: persona.name,
|
||||
value: String(persona.id),
|
||||
})),
|
||||
[nonSearchAgents]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const invalidSelected = values.document_sets.filter((dsId: number) =>
|
||||
unselectableSets.some((us) => us.id === dsId)
|
||||
@@ -374,14 +355,12 @@ export function SlackChannelConfigFormFields({
|
||||
</>
|
||||
</SubLabel>
|
||||
|
||||
<InputComboBox
|
||||
placeholder="Search for an agent..."
|
||||
value={String(values.persona_id ?? "")}
|
||||
onValueChange={(val) =>
|
||||
setFieldValue("persona_id", val ? Number(val) : null)
|
||||
}
|
||||
options={searchAgentOptions}
|
||||
strict
|
||||
<SelectorFormField
|
||||
name="persona_id"
|
||||
options={availableAgents.map((persona) => ({
|
||||
name: persona.name,
|
||||
value: persona.id,
|
||||
}))}
|
||||
/>
|
||||
{viewSyncEnabledAgents && syncEnabledAgents.length > 0 && (
|
||||
<div className="mt-4">
|
||||
@@ -440,14 +419,12 @@ export function SlackChannelConfigFormFields({
|
||||
</>
|
||||
</SubLabel>
|
||||
|
||||
<InputComboBox
|
||||
placeholder="Search for an agent..."
|
||||
value={String(values.persona_id ?? "")}
|
||||
onValueChange={(val) =>
|
||||
setFieldValue("persona_id", val ? Number(val) : null)
|
||||
}
|
||||
options={nonSearchAgentOptions}
|
||||
strict
|
||||
<SelectorFormField
|
||||
name="persona_id"
|
||||
options={nonSearchAgents.map((persona) => ({
|
||||
name: persona.name,
|
||||
value: persona.id,
|
||||
}))}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { defaultTailwindCSS } from "@/components/icons/icons";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import { IconProps } from "@opal/types";
|
||||
|
||||
export interface ModelIconProps extends IconProps {
|
||||
|
||||
@@ -1 +1 @@
|
||||
export { default } from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
export { default } from "@/refresh-pages/admin/LLMProviderConfigurationPage";
|
||||
|
||||
@@ -5,7 +5,7 @@ import { Button } from "@opal/components";
|
||||
import { Text } from "@opal/components";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { SvgEyeOff, SvgX } from "@opal/icons";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import AgentMessage, {
|
||||
AgentMessageProps,
|
||||
} from "@/app/app/message/messageComponents/AgentMessage";
|
||||
@@ -28,8 +28,6 @@ export interface MultiModelPanelProps {
|
||||
isNonPreferredInSelection: boolean;
|
||||
/** Callback when user clicks this panel to select as preferred */
|
||||
onSelect: () => void;
|
||||
/** Callback to deselect this panel as preferred */
|
||||
onDeselect?: () => void;
|
||||
/** Callback to hide/show this panel */
|
||||
onToggleVisibility: () => void;
|
||||
/** Props to pass through to AgentMessage */
|
||||
@@ -65,7 +63,6 @@ export default function MultiModelPanel({
|
||||
isHidden,
|
||||
isNonPreferredInSelection,
|
||||
onSelect,
|
||||
onDeselect,
|
||||
onToggleVisibility,
|
||||
agentMessageProps,
|
||||
errorMessage,
|
||||
@@ -96,25 +93,11 @@ export default function MultiModelPanel({
|
||||
rightChildren={
|
||||
<div className="flex items-center gap-1 px-2">
|
||||
{isPreferred && (
|
||||
<>
|
||||
<span className="text-action-link-05 shrink-0">
|
||||
<Text font="secondary-body" color="inherit" nowrap>
|
||||
Preferred Response
|
||||
</Text>
|
||||
</span>
|
||||
{onDeselect && (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgX}
|
||||
size="sm"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onDeselect();
|
||||
}}
|
||||
tooltip="Deselect preferred response"
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
<span className="text-action-link-05 shrink-0">
|
||||
<Text font="secondary-body" color="inherit" nowrap>
|
||||
Preferred Response
|
||||
</Text>
|
||||
</span>
|
||||
)}
|
||||
{!isPreferred && (
|
||||
<Button
|
||||
|
||||
@@ -30,7 +30,7 @@ const SELECTION_PANEL_W = 400;
|
||||
// Compact width for hidden panels in the carousel track
|
||||
const HIDDEN_PANEL_W = 220;
|
||||
// Generation-mode panel widths (from Figma)
|
||||
const GEN_PANEL_W_2 = 720; // 2 panels side-by-side
|
||||
const GEN_PANEL_W_2 = 640; // 2 panels side-by-side
|
||||
const GEN_PANEL_W_3 = 436; // 3 panels side-by-side
|
||||
// Gap between panels — matches CSS gap-6 (24px)
|
||||
const PANEL_GAP = 24;
|
||||
@@ -64,31 +64,14 @@ export default function MultiModelResponseView({
|
||||
onMessageSelection,
|
||||
onHiddenPanelsChange,
|
||||
}: MultiModelResponseViewProps) {
|
||||
// Initialize preferredIndex from the backend's preferred_response_id when
|
||||
// loading an existing conversation.
|
||||
const [preferredIndex, setPreferredIndex] = useState<number | null>(() => {
|
||||
if (!parentMessage?.preferredResponseId) return null;
|
||||
const match = responses.find(
|
||||
(r) => r.messageId === parentMessage.preferredResponseId
|
||||
);
|
||||
return match?.modelIndex ?? null;
|
||||
});
|
||||
const [preferredIndex, setPreferredIndex] = useState<number | null>(null);
|
||||
const [hiddenPanels, setHiddenPanels] = useState<Set<number>>(new Set());
|
||||
// Controls animation: false = panels at start position, true = panels at peek position
|
||||
const [selectionEntered, setSelectionEntered] = useState(
|
||||
() => preferredIndex !== null
|
||||
);
|
||||
// Tracks the deselect animation timeout so it can be cancelled if the user
|
||||
// re-selects a panel during the 450ms animation window.
|
||||
const deselectTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
// True while the reverse animation is playing (deselect → back to equal panels)
|
||||
const [selectionExiting, setSelectionExiting] = useState(false);
|
||||
const [selectionEntered, setSelectionEntered] = useState(false);
|
||||
// Measures the overflow-hidden carousel container for responsive preferred-panel sizing.
|
||||
const [trackContainerW, setTrackContainerW] = useState(0);
|
||||
const roRef = useRef<ResizeObserver | null>(null);
|
||||
const trackContainerElRef = useRef<HTMLDivElement | null>(null);
|
||||
const trackContainerRef = useCallback((el: HTMLDivElement | null) => {
|
||||
trackContainerElRef.current = el;
|
||||
if (roRef.current) {
|
||||
roRef.current.disconnect();
|
||||
roRef.current = null;
|
||||
@@ -107,9 +90,6 @@ export default function MultiModelResponseView({
|
||||
number | null
|
||||
>(null);
|
||||
const preferredRoRef = useRef<ResizeObserver | null>(null);
|
||||
// Refs to each panel wrapper for height animation on deselect
|
||||
const panelElsRef = useRef<Map<number, HTMLDivElement>>(new Map());
|
||||
|
||||
// Tracks which non-preferred panels overflow the preferred height cap
|
||||
const [overflowingPanels, setOverflowingPanels] = useState<Set<number>>(
|
||||
new Set()
|
||||
@@ -172,48 +152,15 @@ export default function MultiModelResponseView({
|
||||
const handleSelectPreferred = useCallback(
|
||||
(modelIndex: number) => {
|
||||
if (isGenerating) return;
|
||||
|
||||
// Cancel any pending deselect animation so it doesn't overwrite this selection
|
||||
if (deselectTimeoutRef.current !== null) {
|
||||
clearTimeout(deselectTimeoutRef.current);
|
||||
deselectTimeoutRef.current = null;
|
||||
setSelectionExiting(false);
|
||||
}
|
||||
|
||||
// Only freeze scroll when entering selection mode for the first time.
|
||||
// When switching preferred within selection mode, panels are already
|
||||
// capped and the track just slides — no height changes to worry about.
|
||||
const alreadyInSelection = preferredIndex !== null;
|
||||
if (!alreadyInSelection) {
|
||||
const scrollContainer = trackContainerElRef.current?.closest(
|
||||
"[data-chat-scroll]"
|
||||
) as HTMLElement | null;
|
||||
const scrollTop = scrollContainer?.scrollTop ?? 0;
|
||||
if (scrollContainer) scrollContainer.style.overflow = "hidden";
|
||||
|
||||
setTimeout(() => {
|
||||
if (scrollContainer) {
|
||||
scrollContainer.scrollTop = scrollTop;
|
||||
requestAnimationFrame(() => {
|
||||
requestAnimationFrame(() => {
|
||||
if (scrollContainer) {
|
||||
scrollContainer.scrollTop = scrollTop;
|
||||
scrollContainer.style.overflow = "";
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}, 450);
|
||||
}
|
||||
|
||||
setPreferredIndex(modelIndex);
|
||||
const response = responses.find((r) => r.modelIndex === modelIndex);
|
||||
if (!response) return;
|
||||
if (onMessageSelection) {
|
||||
onMessageSelection(response.nodeId);
|
||||
}
|
||||
|
||||
// Persist preferred response + sync `latestChildNodeId`. Backend's
|
||||
// `set_preferred_response` updates `latest_child_message_id`; if the
|
||||
// frontend chain walk disagrees, the next follow-up fails with
|
||||
// "not on the latest mainline".
|
||||
// Persist preferred response to backend + update local tree so the
|
||||
// input bar unblocks (awaitingPreferredSelection clears).
|
||||
if (parentMessage?.messageId && response.messageId && currentSessionId) {
|
||||
setPreferredResponse(parentMessage.messageId, response.messageId).catch(
|
||||
(err) => console.error("Failed to persist preferred response:", err)
|
||||
@@ -229,7 +176,6 @@ export default function MultiModelResponseView({
|
||||
updated.set(parentMessage.nodeId, {
|
||||
...userMsg,
|
||||
preferredResponseId: response.messageId,
|
||||
latestChildNodeId: response.nodeId,
|
||||
});
|
||||
updateSessionMessageTree(currentSessionId, updated);
|
||||
}
|
||||
@@ -239,111 +185,17 @@ export default function MultiModelResponseView({
|
||||
[
|
||||
isGenerating,
|
||||
responses,
|
||||
preferredIndex,
|
||||
onMessageSelection,
|
||||
parentMessage,
|
||||
currentSessionId,
|
||||
updateSessionMessageTree,
|
||||
]
|
||||
);
|
||||
|
||||
// NOTE: Deselect only clears the local tree — no backend call to clear
|
||||
// preferred_response_id. The SetPreferredResponseRequest model doesn't
|
||||
// accept null. A backend endpoint for clearing preference would be needed
|
||||
// if deselect should persist across reloads.
|
||||
const handleDeselectPreferred = useCallback(() => {
|
||||
const scrollContainer = trackContainerElRef.current?.closest(
|
||||
"[data-chat-scroll]"
|
||||
) as HTMLElement | null;
|
||||
|
||||
// Animate panels back to equal positions, then clear preferred after transition
|
||||
setSelectionExiting(true);
|
||||
setSelectionEntered(false);
|
||||
deselectTimeoutRef.current = setTimeout(() => {
|
||||
deselectTimeoutRef.current = null;
|
||||
const scrollTop = scrollContainer?.scrollTop ?? 0;
|
||||
if (scrollContainer) scrollContainer.style.overflow = "hidden";
|
||||
|
||||
// Before clearing state, animate each capped panel's height from
|
||||
// its current clientHeight to its natural scrollHeight.
|
||||
const animations: Animation[] = [];
|
||||
panelElsRef.current.forEach((el, modelIndex) => {
|
||||
if (modelIndex === preferredIndex) return;
|
||||
if (hiddenPanels.has(modelIndex)) return;
|
||||
const from = el.clientHeight;
|
||||
const to = el.scrollHeight;
|
||||
if (to <= from) return;
|
||||
// Lock current height, remove maxHeight cap, then animate
|
||||
el.style.maxHeight = `${from}px`;
|
||||
el.style.overflow = "hidden";
|
||||
const anim = el.animate(
|
||||
[{ maxHeight: `${from}px` }, { maxHeight: `${to}px` }],
|
||||
{
|
||||
duration: 350,
|
||||
easing: "cubic-bezier(0.2, 0, 0, 1)",
|
||||
fill: "forwards",
|
||||
}
|
||||
);
|
||||
animations.push(anim);
|
||||
anim.onfinish = () => {
|
||||
el.style.maxHeight = "";
|
||||
el.style.overflow = "";
|
||||
};
|
||||
});
|
||||
|
||||
setSelectionExiting(false);
|
||||
setPreferredIndex(null);
|
||||
|
||||
// Restore scroll after animations + React settle
|
||||
const restoreScroll = () => {
|
||||
requestAnimationFrame(() => {
|
||||
if (scrollContainer) {
|
||||
scrollContainer.scrollTop = scrollTop;
|
||||
scrollContainer.style.overflow = "";
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
if (animations.length > 0) {
|
||||
Promise.all(animations.map((a) => a.finished))
|
||||
.then(restoreScroll)
|
||||
.catch(restoreScroll);
|
||||
} else {
|
||||
restoreScroll();
|
||||
}
|
||||
|
||||
// Clear preferredResponseId in the local tree so input bar re-gates
|
||||
if (parentMessage && currentSessionId) {
|
||||
const tree = useChatSessionStore
|
||||
.getState()
|
||||
.sessions.get(currentSessionId)?.messageTree;
|
||||
if (tree) {
|
||||
const userMsg = tree.get(parentMessage.nodeId);
|
||||
if (userMsg) {
|
||||
const updated = new Map(tree);
|
||||
updated.set(parentMessage.nodeId, {
|
||||
...userMsg,
|
||||
preferredResponseId: undefined,
|
||||
});
|
||||
updateSessionMessageTree(currentSessionId, updated);
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 450);
|
||||
}, [
|
||||
parentMessage,
|
||||
currentSessionId,
|
||||
updateSessionMessageTree,
|
||||
preferredIndex,
|
||||
hiddenPanels,
|
||||
]);
|
||||
|
||||
// Clear preferred selection when generation starts
|
||||
// Reset selection state when generation restarts
|
||||
useEffect(() => {
|
||||
if (isGenerating) {
|
||||
setPreferredIndex(null);
|
||||
setHasEnteredSelection(false);
|
||||
setSelectionExiting(false);
|
||||
}
|
||||
}, [isGenerating]);
|
||||
|
||||
@@ -352,39 +204,22 @@ export default function MultiModelResponseView({
|
||||
(r) => r.modelIndex === preferredIndex
|
||||
);
|
||||
|
||||
// Track whether selection mode was ever entered — once it has been,
|
||||
// we stay in the selection layout (even after deselect) to avoid a
|
||||
// jarring DOM swap between the two layout strategies.
|
||||
const [hasEnteredSelection, setHasEnteredSelection] = useState(
|
||||
() => preferredIndex !== null
|
||||
);
|
||||
|
||||
const isActivelySelected =
|
||||
// Selection mode when preferred is set, found in responses, not generating, and at least 2 visible panels
|
||||
const showSelectionMode =
|
||||
preferredIndex !== null &&
|
||||
preferredIdx !== -1 &&
|
||||
!isGenerating &&
|
||||
visibleResponses.length > 1;
|
||||
|
||||
// Trigger the slide-out animation one frame after entering selection mode
|
||||
useEffect(() => {
|
||||
if (isActivelySelected) setHasEnteredSelection(true);
|
||||
}, [isActivelySelected]);
|
||||
|
||||
// Use the selection layout once a preferred response has been chosen,
|
||||
// even after deselect. Only fall through to generation layout before
|
||||
// the first selection or during active streaming.
|
||||
const showSelectionMode = isActivelySelected || hasEnteredSelection;
|
||||
|
||||
// Trigger the slide-out animation one frame after a preferred panel is selected.
|
||||
// Uses isActivelySelected (not showSelectionMode) so re-selecting after a
|
||||
// deselect still triggers the animation.
|
||||
useEffect(() => {
|
||||
if (!isActivelySelected) {
|
||||
// Don't reset selectionEntered here — handleDeselectPreferred manages it
|
||||
if (!showSelectionMode) {
|
||||
setSelectionEntered(false);
|
||||
return;
|
||||
}
|
||||
const raf = requestAnimationFrame(() => setSelectionEntered(true));
|
||||
return () => cancelAnimationFrame(raf);
|
||||
}, [isActivelySelected]);
|
||||
}, [showSelectionMode]);
|
||||
|
||||
// Build panel props — isHidden reflects actual hidden state
|
||||
const buildPanelProps = useCallback(
|
||||
@@ -396,7 +231,6 @@ export default function MultiModelResponseView({
|
||||
isHidden: hiddenPanels.has(response.modelIndex),
|
||||
isNonPreferredInSelection: isNonPreferred,
|
||||
onSelect: () => handleSelectPreferred(response.modelIndex),
|
||||
onDeselect: handleDeselectPreferred,
|
||||
onToggleVisibility: () => toggleVisibility(response.modelIndex),
|
||||
agentMessageProps: {
|
||||
rawPackets: response.packets,
|
||||
@@ -421,7 +255,6 @@ export default function MultiModelResponseView({
|
||||
preferredIndex,
|
||||
hiddenPanels,
|
||||
handleSelectPreferred,
|
||||
handleDeselectPreferred,
|
||||
toggleVisibility,
|
||||
chatState,
|
||||
llmManager,
|
||||
@@ -477,30 +310,25 @@ export default function MultiModelResponseView({
|
||||
<div
|
||||
ref={trackContainerRef}
|
||||
className="w-full overflow-hidden"
|
||||
style={
|
||||
isActivelySelected
|
||||
? {
|
||||
maskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
WebkitMaskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
style={{
|
||||
maskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
WebkitMaskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
className="flex items-start"
|
||||
style={{
|
||||
gap: `${PANEL_GAP}px`,
|
||||
transition:
|
||||
selectionEntered || selectionExiting
|
||||
? "transform 0.45s cubic-bezier(0.2, 0, 0, 1)"
|
||||
: "none",
|
||||
transition: selectionEntered
|
||||
? "transform 0.45s cubic-bezier(0.2, 0, 0, 1)"
|
||||
: "none",
|
||||
transform: trackTransform,
|
||||
}}
|
||||
>
|
||||
{responses.map((r, i) => {
|
||||
const isHidden = hiddenPanels.has(r.modelIndex);
|
||||
const isPref = r.modelIndex === preferredIndex;
|
||||
const isNonPref = !isHidden && !isPref && preferredIndex !== null;
|
||||
const isNonPref = !isHidden && !isPref;
|
||||
const finalW = selectionWidths[i]!;
|
||||
const startW = isHidden ? HIDDEN_PANEL_W : SELECTION_PANEL_W;
|
||||
const capped = isNonPref && preferredPanelHeight != null;
|
||||
@@ -509,11 +337,6 @@ export default function MultiModelResponseView({
|
||||
<div
|
||||
key={r.modelIndex}
|
||||
ref={(el) => {
|
||||
if (el) {
|
||||
panelElsRef.current.set(r.modelIndex, el);
|
||||
} else {
|
||||
panelElsRef.current.delete(r.modelIndex);
|
||||
}
|
||||
if (isPref) preferredPanelRef(el);
|
||||
if (capped && el) {
|
||||
const doesOverflow = el.scrollHeight > el.clientHeight;
|
||||
@@ -530,10 +353,9 @@ export default function MultiModelResponseView({
|
||||
style={{
|
||||
width: `${selectionEntered ? finalW : startW}px`,
|
||||
flexShrink: 0,
|
||||
transition:
|
||||
selectionEntered || selectionExiting
|
||||
? "width 0.45s cubic-bezier(0.2, 0, 0, 1)"
|
||||
: "none",
|
||||
transition: selectionEntered
|
||||
? "width 0.45s cubic-bezier(0.2, 0, 0, 1)"
|
||||
: "none",
|
||||
maxHeight: capped ? preferredPanelHeight : undefined,
|
||||
overflow: capped ? "hidden" : undefined,
|
||||
position: capped ? "relative" : undefined,
|
||||
@@ -566,7 +388,7 @@ export default function MultiModelResponseView({
|
||||
|
||||
return (
|
||||
<div className="overflow-x-auto">
|
||||
<div className="flex gap-6 items-start justify-center w-full">
|
||||
<div className="flex gap-6 items-start w-full">
|
||||
{responses.map((r) => {
|
||||
const isHidden = hiddenPanels.has(r.modelIndex);
|
||||
return (
|
||||
|
||||
@@ -18,7 +18,7 @@ import {
|
||||
isRecommendedModel,
|
||||
} from "@/app/craft/onboarding/constants";
|
||||
import { ToggleWarningModal } from "./ToggleWarningModal";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import {
|
||||
Accordion,
|
||||
|
||||
@@ -48,7 +48,7 @@ import NotAllowedModal from "@/app/craft/onboarding/components/NotAllowedModal";
|
||||
import { useOnboarding } from "@/app/craft/onboarding/BuildOnboardingProvider";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import {
|
||||
getBuildUserPersona,
|
||||
getPersonaInfo,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
:root {
|
||||
--app-page-main-content-width: 45rem;
|
||||
--app-page-main-content-width: 52.5rem;
|
||||
--block-width-form-input-min: 10rem;
|
||||
|
||||
--container-sm: 42rem;
|
||||
|
||||
@@ -45,9 +45,6 @@ import { personaIncludesRetrieval } from "@/app/app/services/lib";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { eeGated } from "@/ce";
|
||||
import EESearchUI from "@/ee/sections/SearchUI";
|
||||
import useMultiModelChat from "@/hooks/useMultiModelChat";
|
||||
import ModelSelector from "@/refresh-components/popovers/ModelSelector";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
const SearchUI = eeGated(EESearchUI);
|
||||
|
||||
@@ -108,20 +105,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
// If no LLM provider is configured (e.g., fresh signup), the input bar is
|
||||
// disabled and a "Set up an LLM" button is shown (see bottom of component).
|
||||
const llmManager = useLlmManager(undefined, liveAgent ?? undefined);
|
||||
const multiModel = useMultiModelChat(llmManager);
|
||||
|
||||
// Sync single-model selection to llmManager so the submission path
|
||||
// uses the correct provider/version (mirrors AppPage behaviour).
|
||||
useEffect(() => {
|
||||
if (multiModel.selectedModels.length === 1) {
|
||||
const model = multiModel.selectedModels[0]!;
|
||||
llmManager.updateCurrentLlm({
|
||||
name: model.name,
|
||||
provider: model.provider,
|
||||
modelName: model.modelName,
|
||||
});
|
||||
}
|
||||
}, [multiModel.selectedModels]);
|
||||
|
||||
// Deep research toggle
|
||||
const { deepResearchEnabled, toggleDeepResearch } = useDeepResearchToggle({
|
||||
@@ -312,17 +295,12 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
|
||||
// If we already have messages (chat session started), always use chat mode
|
||||
// (matches AppPage behavior where existing sessions bypass classification)
|
||||
const selectedModels = multiModel.isMultiModelActive
|
||||
? multiModel.selectedModels
|
||||
: undefined;
|
||||
|
||||
if (hasMessages) {
|
||||
onSubmit({
|
||||
message: submittedMessage,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled,
|
||||
additionalContext,
|
||||
selectedModels,
|
||||
});
|
||||
return;
|
||||
}
|
||||
@@ -334,7 +312,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled,
|
||||
additionalContext,
|
||||
selectedModels,
|
||||
});
|
||||
};
|
||||
|
||||
@@ -351,8 +328,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
submitQuery,
|
||||
tabReadingEnabled,
|
||||
currentTabUrl,
|
||||
multiModel.isMultiModelActive,
|
||||
multiModel.selectedModels,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -481,7 +456,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
onResubmit={handleResubmitLastMessage}
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
anchorNodeId={anchorNodeId}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
/>
|
||||
</ChatScrollContainer>
|
||||
</>
|
||||
@@ -490,23 +464,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
{/* Welcome message - centered when no messages and not in search mode */}
|
||||
{!hasMessages && !isSearch && (
|
||||
<div className="relative w-full flex-1 flex flex-col items-center justify-end">
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="between"
|
||||
alignItems="end"
|
||||
className="max-w-[var(--app-page-main-content-width)]"
|
||||
>
|
||||
<WelcomeMessage isDefaultAgent />
|
||||
{liveAgent && !llmManager.isLoadingProviders && (
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
)}
|
||||
</Section>
|
||||
<WelcomeMessage isDefaultAgent />
|
||||
<Spacer rem={1.5} />
|
||||
</div>
|
||||
)}
|
||||
@@ -520,17 +478,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
"max-w-[var(--app-page-main-content-width)] px-4"
|
||||
)}
|
||||
>
|
||||
{hasMessages && liveAgent && !llmManager.isLoadingProviders && (
|
||||
<div className="pb-1">
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<AppInputBar
|
||||
ref={chatInputBarRef}
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import { useMemo } from "react";
|
||||
import { parseLlmDescriptor, structureValue } from "@/lib/llmConfig/utils";
|
||||
import { DefaultModel, LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { createIcon } from "@/components/icons/icons";
|
||||
|
||||
|
||||
@@ -694,25 +694,6 @@ export function useLlmManager(
|
||||
prevAgentIdRef.current = liveAgent?.id;
|
||||
}, [liveAgent?.id]);
|
||||
|
||||
// Clear manual override when arriving at a *different* existing session
|
||||
// from any previously-seen defined session. Tracks only the last
|
||||
// *defined* session id so a round-trip through new-chat (A → undefined
|
||||
// → B) still resets, while A → undefined (new-chat) preserves it.
|
||||
const prevDefinedSessionIdRef = useRef<string | undefined>(undefined);
|
||||
useEffect(() => {
|
||||
const nextId = currentChatSession?.id;
|
||||
if (
|
||||
nextId !== undefined &&
|
||||
prevDefinedSessionIdRef.current !== undefined &&
|
||||
nextId !== prevDefinedSessionIdRef.current
|
||||
) {
|
||||
setUserHasManuallyOverriddenLLM(false);
|
||||
}
|
||||
if (nextId !== undefined) {
|
||||
prevDefinedSessionIdRef.current = nextId;
|
||||
}
|
||||
}, [currentChatSession?.id]);
|
||||
|
||||
function getValidLlmDescriptor(
|
||||
modelName: string | null | undefined
|
||||
): LlmDescriptor {
|
||||
@@ -734,9 +715,8 @@ export function useLlmManager(
|
||||
|
||||
if (llmProviders === undefined || llmProviders === null) {
|
||||
resolved = manualLlm;
|
||||
} else if (userHasManuallyOverriddenLLM) {
|
||||
// Manual override wins over session's `current_alternate_model`.
|
||||
// Cleared on cross-session navigation by the effect above.
|
||||
} else if (userHasManuallyOverriddenLLM && !currentChatSession) {
|
||||
// User has overridden in this session and switched to a new session
|
||||
resolved = manualLlm;
|
||||
} else if (currentChatSession?.current_alternate_model) {
|
||||
resolved = getValidLlmDescriptorForProviders(
|
||||
@@ -748,6 +728,8 @@ export function useLlmManager(
|
||||
liveAgent.llm_model_version_override,
|
||||
llmProviders
|
||||
);
|
||||
} else if (userHasManuallyOverriddenLLM) {
|
||||
resolved = manualLlm;
|
||||
} else if (user?.preferences?.default_model) {
|
||||
resolved = getValidLlmDescriptorForProviders(
|
||||
user.preferences.default_model,
|
||||
|
||||
@@ -1,251 +0,0 @@
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { SvgCpu, SvgPlug, SvgServer } from "@opal/icons";
|
||||
import {
|
||||
SvgBifrost,
|
||||
SvgOpenai,
|
||||
SvgClaude,
|
||||
SvgOllama,
|
||||
SvgAws,
|
||||
SvgOpenrouter,
|
||||
SvgAzure,
|
||||
SvgGemini,
|
||||
SvgLitellm,
|
||||
SvgLmStudio,
|
||||
SvgMicrosoft,
|
||||
SvgMistral,
|
||||
SvgDeepseek,
|
||||
SvgQwen,
|
||||
SvgGoogle,
|
||||
} from "@opal/logos";
|
||||
import { ZAIIcon } from "@/components/icons/icons";
|
||||
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
|
||||
import type { LLMProviderView } from "@/interfaces/llm";
|
||||
import OpenAIModal from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
import AnthropicModal from "@/sections/modals/llmConfig/AnthropicModal";
|
||||
import OllamaModal from "@/sections/modals/llmConfig/OllamaModal";
|
||||
import AzureModal from "@/sections/modals/llmConfig/AzureModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
|
||||
// ─── Text (LLM) providers ────────────────────────────────────────────────────
|
||||
|
||||
export interface ProviderEntry {
|
||||
icon: IconFunctionComponent;
|
||||
productName: string;
|
||||
companyName: string;
|
||||
Modal: React.ComponentType<LLMProviderFormProps>;
|
||||
}
|
||||
|
||||
const PROVIDERS: Record<string, ProviderEntry> = {
|
||||
[LLMProviderName.OPENAI]: {
|
||||
icon: SvgOpenai,
|
||||
productName: "GPT",
|
||||
companyName: "OpenAI",
|
||||
Modal: OpenAIModal,
|
||||
},
|
||||
[LLMProviderName.ANTHROPIC]: {
|
||||
icon: SvgClaude,
|
||||
productName: "Claude",
|
||||
companyName: "Anthropic",
|
||||
Modal: AnthropicModal,
|
||||
},
|
||||
[LLMProviderName.VERTEX_AI]: {
|
||||
icon: SvgGemini,
|
||||
productName: "Gemini",
|
||||
companyName: "Google Cloud Vertex AI",
|
||||
Modal: VertexAIModal,
|
||||
},
|
||||
[LLMProviderName.BEDROCK]: {
|
||||
icon: SvgAws,
|
||||
productName: "Amazon Bedrock",
|
||||
companyName: "AWS",
|
||||
Modal: BedrockModal,
|
||||
},
|
||||
[LLMProviderName.AZURE]: {
|
||||
icon: SvgAzure,
|
||||
productName: "Azure OpenAI",
|
||||
companyName: "Microsoft Azure",
|
||||
Modal: AzureModal,
|
||||
},
|
||||
[LLMProviderName.LITELLM]: {
|
||||
icon: SvgLitellm,
|
||||
productName: "LiteLLM",
|
||||
companyName: "LiteLLM",
|
||||
Modal: CustomModal,
|
||||
},
|
||||
[LLMProviderName.LITELLM_PROXY]: {
|
||||
icon: SvgLitellm,
|
||||
productName: "LiteLLM Proxy",
|
||||
companyName: "LiteLLM Proxy",
|
||||
Modal: LiteLLMProxyModal,
|
||||
},
|
||||
[LLMProviderName.OLLAMA_CHAT]: {
|
||||
icon: SvgOllama,
|
||||
productName: "Ollama",
|
||||
companyName: "Ollama",
|
||||
Modal: OllamaModal,
|
||||
},
|
||||
[LLMProviderName.OPENROUTER]: {
|
||||
icon: SvgOpenrouter,
|
||||
productName: "OpenRouter",
|
||||
companyName: "OpenRouter",
|
||||
Modal: OpenRouterModal,
|
||||
},
|
||||
[LLMProviderName.LM_STUDIO]: {
|
||||
icon: SvgLmStudio,
|
||||
productName: "LM Studio",
|
||||
companyName: "LM Studio",
|
||||
Modal: LMStudioModal,
|
||||
},
|
||||
[LLMProviderName.BIFROST]: {
|
||||
icon: SvgBifrost,
|
||||
productName: "Bifrost",
|
||||
companyName: "Bifrost",
|
||||
Modal: BifrostModal,
|
||||
},
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: {
|
||||
icon: SvgPlug,
|
||||
productName: "OpenAI-Compatible",
|
||||
companyName: "OpenAI-Compatible",
|
||||
Modal: OpenAICompatibleModal,
|
||||
},
|
||||
[LLMProviderName.CUSTOM]: {
|
||||
icon: SvgServer,
|
||||
productName: "Custom Models",
|
||||
companyName: "models from other LiteLLM-compatible providers",
|
||||
Modal: CustomModal,
|
||||
},
|
||||
};
|
||||
|
||||
const DEFAULT_ENTRY: ProviderEntry = {
|
||||
icon: SvgCpu,
|
||||
productName: "",
|
||||
companyName: "",
|
||||
Modal: CustomModal,
|
||||
};
|
||||
|
||||
// Providers that don't use custom_config themselves — if custom_config is
|
||||
// present it means the provider was originally created via CustomModal.
|
||||
const CUSTOM_CONFIG_OVERRIDES = new Set<string>([
|
||||
LLMProviderName.OPENAI,
|
||||
LLMProviderName.ANTHROPIC,
|
||||
LLMProviderName.AZURE,
|
||||
LLMProviderName.OPENROUTER,
|
||||
]);
|
||||
|
||||
export function getProvider(
|
||||
providerName: string,
|
||||
existingProvider?: LLMProviderView
|
||||
): ProviderEntry {
|
||||
const entry = PROVIDERS[providerName] ?? {
|
||||
...DEFAULT_ENTRY,
|
||||
productName: providerName,
|
||||
companyName: providerName,
|
||||
};
|
||||
|
||||
if (
|
||||
existingProvider?.custom_config != null &&
|
||||
CUSTOM_CONFIG_OVERRIDES.has(providerName)
|
||||
) {
|
||||
return { ...entry, Modal: CustomModal };
|
||||
}
|
||||
|
||||
return entry;
|
||||
}
|
||||
|
||||
// ─── Aggregator providers ────────────────────────────────────────────────────
|
||||
// Providers that host models from multiple vendors (e.g. Bedrock hosts Claude,
|
||||
// Llama, etc.) Used by the model-icon resolver to prioritise vendor icons.
|
||||
|
||||
export const AGGREGATOR_PROVIDERS = new Set([
|
||||
LLMProviderName.BEDROCK,
|
||||
"bedrock_converse",
|
||||
LLMProviderName.OPENROUTER,
|
||||
LLMProviderName.OLLAMA_CHAT,
|
||||
LLMProviderName.LM_STUDIO,
|
||||
LLMProviderName.LITELLM_PROXY,
|
||||
LLMProviderName.BIFROST,
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
LLMProviderName.VERTEX_AI,
|
||||
]);
|
||||
|
||||
// ─── Model-aware icon resolver ───────────────────────────────────────────────
|
||||
|
||||
const MODEL_ICON_MAP: Record<string, IconFunctionComponent> = {
|
||||
[LLMProviderName.OPENAI]: SvgOpenai,
|
||||
[LLMProviderName.ANTHROPIC]: SvgClaude,
|
||||
[LLMProviderName.OLLAMA_CHAT]: SvgOllama,
|
||||
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
|
||||
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
|
||||
[LLMProviderName.VERTEX_AI]: SvgGemini,
|
||||
[LLMProviderName.BEDROCK]: SvgAws,
|
||||
[LLMProviderName.LITELLM_PROXY]: SvgLitellm,
|
||||
[LLMProviderName.BIFROST]: SvgBifrost,
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: SvgPlug,
|
||||
|
||||
amazon: SvgAws,
|
||||
phi: SvgMicrosoft,
|
||||
mistral: SvgMistral,
|
||||
ministral: SvgMistral,
|
||||
llama: SvgCpu,
|
||||
ollama: SvgOllama,
|
||||
gemini: SvgGemini,
|
||||
deepseek: SvgDeepseek,
|
||||
claude: SvgClaude,
|
||||
azure: SvgAzure,
|
||||
microsoft: SvgMicrosoft,
|
||||
meta: SvgCpu,
|
||||
google: SvgGoogle,
|
||||
qwen: SvgQwen,
|
||||
qwq: SvgQwen,
|
||||
zai: ZAIIcon,
|
||||
bedrock_converse: SvgAws,
|
||||
};
|
||||
|
||||
/**
|
||||
* Model-aware icon resolver that checks both provider name and model name
|
||||
* to pick the most specific icon (e.g. Claude icon for a Bedrock Claude model).
|
||||
*/
|
||||
export function getModelIcon(
|
||||
providerName: string,
|
||||
modelName?: string
|
||||
): IconFunctionComponent {
|
||||
const lowerProviderName = providerName.toLowerCase();
|
||||
|
||||
// For aggregator providers, prioritise showing the vendor icon based on model name
|
||||
if (AGGREGATOR_PROVIDERS.has(lowerProviderName) && modelName) {
|
||||
const lowerModelName = modelName.toLowerCase();
|
||||
for (const [key, icon] of Object.entries(MODEL_ICON_MAP)) {
|
||||
if (lowerModelName.includes(key)) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if provider name directly matches an icon
|
||||
if (lowerProviderName in MODEL_ICON_MAP) {
|
||||
const icon = MODEL_ICON_MAP[lowerProviderName];
|
||||
if (icon) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
|
||||
// For non-aggregator providers, check if model name contains any of the keys
|
||||
if (modelName) {
|
||||
const lowerModelName = modelName.toLowerCase();
|
||||
for (const [key, icon] of Object.entries(MODEL_ICON_MAP)) {
|
||||
if (lowerModelName.includes(key)) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to CPU icon if no matches
|
||||
return SvgCpu;
|
||||
}
|
||||
176
web/src/lib/llmConfig/providers.ts
Normal file
176
web/src/lib/llmConfig/providers.ts
Normal file
@@ -0,0 +1,176 @@
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { SvgCpu, SvgPlug, SvgServer } from "@opal/icons";
|
||||
import {
|
||||
SvgBifrost,
|
||||
SvgOpenai,
|
||||
SvgClaude,
|
||||
SvgOllama,
|
||||
SvgAws,
|
||||
SvgOpenrouter,
|
||||
SvgAzure,
|
||||
SvgGemini,
|
||||
SvgLitellm,
|
||||
SvgLmStudio,
|
||||
SvgMicrosoft,
|
||||
SvgMistral,
|
||||
SvgDeepseek,
|
||||
SvgQwen,
|
||||
SvgGoogle,
|
||||
} from "@opal/logos";
|
||||
import { ZAIIcon } from "@/components/icons/icons";
|
||||
import { LLMProviderName } from "@/interfaces/llm";
|
||||
|
||||
export const AGGREGATOR_PROVIDERS = new Set([
|
||||
LLMProviderName.BEDROCK,
|
||||
"bedrock_converse",
|
||||
LLMProviderName.OPENROUTER,
|
||||
LLMProviderName.OLLAMA_CHAT,
|
||||
LLMProviderName.LM_STUDIO,
|
||||
LLMProviderName.LITELLM_PROXY,
|
||||
LLMProviderName.BIFROST,
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
LLMProviderName.VERTEX_AI,
|
||||
]);
|
||||
|
||||
const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
|
||||
[LLMProviderName.OPENAI]: SvgOpenai,
|
||||
[LLMProviderName.ANTHROPIC]: SvgClaude,
|
||||
[LLMProviderName.VERTEX_AI]: SvgGemini,
|
||||
[LLMProviderName.BEDROCK]: SvgAws,
|
||||
[LLMProviderName.AZURE]: SvgAzure,
|
||||
[LLMProviderName.LITELLM]: SvgLitellm,
|
||||
[LLMProviderName.LITELLM_PROXY]: SvgLitellm,
|
||||
[LLMProviderName.OLLAMA_CHAT]: SvgOllama,
|
||||
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
|
||||
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
|
||||
[LLMProviderName.BIFROST]: SvgBifrost,
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: SvgPlug,
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: SvgServer,
|
||||
};
|
||||
|
||||
const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OPENAI]: "GPT",
|
||||
[LLMProviderName.ANTHROPIC]: "Claude",
|
||||
[LLMProviderName.VERTEX_AI]: "Gemini",
|
||||
[LLMProviderName.BEDROCK]: "Amazon Bedrock",
|
||||
[LLMProviderName.AZURE]: "Azure OpenAI",
|
||||
[LLMProviderName.LITELLM]: "LiteLLM",
|
||||
[LLMProviderName.LITELLM_PROXY]: "LiteLLM Proxy",
|
||||
[LLMProviderName.OLLAMA_CHAT]: "Ollama",
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI-Compatible",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Custom Models",
|
||||
};
|
||||
|
||||
// Longer company/platform display name for each provider (e.g. card
// descriptions), resolved via getProviderDisplayName(); unknown providers
// echo their own name back from that helper.
const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
  [LLMProviderName.OPENAI]: "OpenAI",
  [LLMProviderName.ANTHROPIC]: "Anthropic",
  [LLMProviderName.VERTEX_AI]: "Google Cloud Vertex AI",
  [LLMProviderName.BEDROCK]: "AWS",
  [LLMProviderName.AZURE]: "Microsoft Azure",
  [LLMProviderName.LITELLM]: "LiteLLM",
  [LLMProviderName.LITELLM_PROXY]: "LiteLLM Proxy",
  [LLMProviderName.OLLAMA_CHAT]: "Ollama",
  [LLMProviderName.OPENROUTER]: "OpenRouter",
  [LLMProviderName.LM_STUDIO]: "LM Studio",
  [LLMProviderName.BIFROST]: "Bifrost",
  [LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI-Compatible",

  // fallback
  [LLMProviderName.CUSTOM]: "models from other LiteLLM-compatible providers",
};
|
||||
|
||||
export function getProviderProductName(providerName: string): string {
|
||||
return PROVIDER_PRODUCT_NAMES[providerName] ?? providerName;
|
||||
}
|
||||
|
||||
export function getProviderDisplayName(providerName: string): string {
|
||||
return PROVIDER_DISPLAY_NAMES[providerName] ?? providerName;
|
||||
}
|
||||
|
||||
export function getProviderIcon(providerName: string): IconFunctionComponent {
|
||||
return PROVIDER_ICONS[providerName] ?? SvgCpu;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model-aware icon resolver (legacy icon set)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Keys are lowercase fragments matched as substrings against provider and
// model names by getModelIcon(). NOTE: getModelIcon scans entries in
// insertion order and returns the FIRST substring match, so more-specific
// fragments should precede broader ones (e.g. "llama" here precedes
// "ollama", so a model name containing "ollama" matches "llama" first).
const MODEL_ICON_MAP: Record<string, IconFunctionComponent> = {
  [LLMProviderName.OPENAI]: SvgOpenai,
  [LLMProviderName.ANTHROPIC]: SvgClaude,
  [LLMProviderName.OLLAMA_CHAT]: SvgOllama,
  [LLMProviderName.LM_STUDIO]: SvgLmStudio,
  [LLMProviderName.OPENROUTER]: SvgOpenrouter,
  [LLMProviderName.VERTEX_AI]: SvgGemini,
  [LLMProviderName.BEDROCK]: SvgAws,
  [LLMProviderName.LITELLM_PROXY]: SvgLitellm,
  [LLMProviderName.BIFROST]: SvgBifrost,
  [LLMProviderName.OPENAI_COMPATIBLE]: SvgPlug,

  // Vendor/model-family fragments (not provider names).
  amazon: SvgAws,
  phi: SvgMicrosoft,
  mistral: SvgMistral,
  ministral: SvgMistral,
  llama: SvgCpu,
  ollama: SvgOllama,
  gemini: SvgGemini,
  deepseek: SvgDeepseek,
  claude: SvgClaude,
  azure: SvgAzure,
  microsoft: SvgMicrosoft,
  meta: SvgCpu,
  google: SvgGoogle,
  qwen: SvgQwen,
  qwq: SvgQwen,
  zai: ZAIIcon,
  // Alternate LiteLLM provider key for Bedrock.
  bedrock_converse: SvgAws,
};
|
||||
|
||||
/**
|
||||
* Model-aware icon resolver that checks both provider name and model name
|
||||
* to pick the most specific icon (e.g. Claude icon for a Bedrock Claude model).
|
||||
*/
|
||||
export const getModelIcon = (
|
||||
providerName: string,
|
||||
modelName?: string
|
||||
): IconFunctionComponent => {
|
||||
const lowerProviderName = providerName.toLowerCase();
|
||||
|
||||
// For aggregator providers, prioritise showing the vendor icon based on model name
|
||||
if (AGGREGATOR_PROVIDERS.has(lowerProviderName) && modelName) {
|
||||
const lowerModelName = modelName.toLowerCase();
|
||||
for (const [key, icon] of Object.entries(MODEL_ICON_MAP)) {
|
||||
if (lowerModelName.includes(key)) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if provider name directly matches an icon
|
||||
if (lowerProviderName in MODEL_ICON_MAP) {
|
||||
const icon = MODEL_ICON_MAP[lowerProviderName];
|
||||
if (icon) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
|
||||
// For non-aggregator providers, check if model name contains any of the keys
|
||||
if (modelName) {
|
||||
const lowerModelName = modelName.toLowerCase();
|
||||
for (const [key, icon] of Object.entries(MODEL_ICON_MAP)) {
|
||||
if (lowerModelName.includes(key)) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to CPU icon if no matches
|
||||
return SvgCpu;
|
||||
};
|
||||
@@ -44,7 +44,7 @@ export function getFinalLLM(
|
||||
return [provider, model];
|
||||
}
|
||||
|
||||
export function getProviderOverrideForPersona(
|
||||
export function getLLMProviderOverrideForPersona(
|
||||
liveAgent: MinimalPersonaSnapshot,
|
||||
llmProviders: LLMProviderDescriptor[]
|
||||
): LlmDescriptor | null {
|
||||
@@ -144,7 +144,7 @@ export function getDisplayName(
|
||||
agent: MinimalPersonaSnapshot,
|
||||
llmProviders: LLMProviderDescriptor[]
|
||||
): string | undefined {
|
||||
const llmDescriptor = getProviderOverrideForPersona(
|
||||
const llmDescriptor = getLLMProviderOverrideForPersona(
|
||||
agent,
|
||||
llmProviders ?? []
|
||||
);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import copy from "copy-to-clipboard";
|
||||
import { Button, ButtonProps } from "@opal/components";
|
||||
import { SvgAlertTriangle, SvgCheck, SvgCopy } from "@opal/icons";
|
||||
|
||||
@@ -41,19 +40,26 @@ export default function CopyIconButton({
|
||||
}
|
||||
|
||||
try {
|
||||
if (navigator.clipboard && getHtmlContent) {
|
||||
// Check if Clipboard API is available
|
||||
if (!navigator.clipboard) {
|
||||
throw new Error("Clipboard API not available");
|
||||
}
|
||||
|
||||
// If HTML content getter is provided, copy both HTML and plain text
|
||||
if (getHtmlContent) {
|
||||
const htmlContent = getHtmlContent();
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob([htmlContent], { type: "text/html" }),
|
||||
"text/plain": new Blob([text], { type: "text/plain" }),
|
||||
});
|
||||
await navigator.clipboard.write([clipboardItem]);
|
||||
} else if (navigator.clipboard) {
|
||||
}
|
||||
// Default: plain text only
|
||||
else {
|
||||
await navigator.clipboard.writeText(text);
|
||||
} else if (!copy(text)) {
|
||||
throw new Error("copy-to-clipboard returned false");
|
||||
}
|
||||
|
||||
// Show "copied" state
|
||||
setCopyState("copied");
|
||||
} catch (err) {
|
||||
console.error("Failed to copy:", err);
|
||||
|
||||
@@ -4,7 +4,7 @@ import { useState, useEffect, useCallback, useMemo, useRef } from "react";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import { LlmDescriptor, LlmManager } from "@/lib/hooks";
|
||||
import { structureValue } from "@/lib/llmConfig/utils";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import { AGGREGATOR_PROVIDERS } from "@/lib/llmConfig/svc";
|
||||
|
||||
import { Slider } from "@/components/ui/slider";
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import { useState, useMemo, useRef } from "react";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { getModelIcon } from "@/lib/llmConfig";
|
||||
import { getModelIcon } from "@/lib/llmConfig/providers";
|
||||
import { Button, SelectButton, OpenButton } from "@opal/components";
|
||||
import { SvgPlusCircle, SvgX } from "@opal/icons";
|
||||
import { LLMOption } from "@/refresh-components/popovers/interfaces";
|
||||
@@ -104,7 +104,6 @@ export default function ModelSelector({
|
||||
onRemove(existingIndex);
|
||||
} else if (!atMax) {
|
||||
onAdd(model);
|
||||
setOpen(false);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -159,12 +158,9 @@ export default function ModelSelector({
|
||||
);
|
||||
|
||||
if (!isMultiModel) {
|
||||
// Stable key — keying on model would unmount the pill
|
||||
// on change and leave Radix's anchorRef detached,
|
||||
// flashing the closing popover at (0,0).
|
||||
return (
|
||||
<OpenButton
|
||||
key="single-model-pill"
|
||||
key={modelKey(model.provider, model.modelName)}
|
||||
icon={ProviderIcon}
|
||||
onClick={(e: React.MouseEvent) =>
|
||||
handlePillClick(index, e.currentTarget as HTMLElement)
|
||||
@@ -218,17 +214,15 @@ export default function ModelSelector({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{!(atMax && replacingIndex === null) && (
|
||||
<Popover.Content side="top" align="end" width="lg">
|
||||
<ModelListContent
|
||||
llmProviders={llmManager.llmProviders}
|
||||
isLoading={llmManager.isLoadingProviders}
|
||||
onSelect={handleSelect}
|
||||
isSelected={isSelected}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Popover.Content>
|
||||
)}
|
||||
<Popover.Content side="top" align="end" width="lg">
|
||||
<ModelListContent
|
||||
llmProviders={llmManager.llmProviders}
|
||||
isLoading={llmManager.isLoadingProviders}
|
||||
onSelect={handleSelect}
|
||||
isSelected={isSelected}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -400,22 +400,19 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const multiModel = useMultiModelChat(llmManager);
|
||||
|
||||
// Auto-fold sidebar when a multi-model message is submitted.
|
||||
// Stays collapsed until the user exits multi-model mode (removes models).
|
||||
// Auto-fold sidebar when multi-model is active (panels need full width)
|
||||
const { folded: sidebarFolded, setFolded: setSidebarFolded } =
|
||||
useSidebarState();
|
||||
const preMultiModelFoldedRef = useRef<boolean | null>(null);
|
||||
|
||||
const foldSidebarForMultiModel = useCallback(() => {
|
||||
if (preMultiModelFoldedRef.current === null) {
|
||||
preMultiModelFoldedRef.current = sidebarFolded;
|
||||
setSidebarFolded(true);
|
||||
}
|
||||
}, [sidebarFolded, setSidebarFolded]);
|
||||
|
||||
// Restore sidebar when user exits multi-model mode
|
||||
useEffect(() => {
|
||||
if (
|
||||
multiModel.isMultiModelActive &&
|
||||
preMultiModelFoldedRef.current === null
|
||||
) {
|
||||
preMultiModelFoldedRef.current = sidebarFolded;
|
||||
setSidebarFolded(true);
|
||||
} else if (
|
||||
!multiModel.isMultiModelActive &&
|
||||
preMultiModelFoldedRef.current !== null
|
||||
) {
|
||||
@@ -425,27 +422,16 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [multiModel.isMultiModelActive]);
|
||||
|
||||
// Sync single-model selection to llmManager so the submission path uses
|
||||
// the correct provider/version. Guard against echoing derived state back
|
||||
// — only call updateCurrentLlm when the selection actually differs from
|
||||
// currentLlm, otherwise the initial [] → [currentLlmModel] sync would
|
||||
// pin `userHasManuallyOverriddenLLM=true` with whatever was resolved
|
||||
// first (often the default model before the session's alt_model loads).
|
||||
// Sync single-model selection to llmManager so the submission path
|
||||
// uses the correct provider/version (replaces the old LLMPopover sync).
|
||||
useEffect(() => {
|
||||
if (multiModel.selectedModels.length === 1) {
|
||||
const model = multiModel.selectedModels[0]!;
|
||||
const current = llmManager.currentLlm;
|
||||
if (
|
||||
model.provider !== current.provider ||
|
||||
model.modelName !== current.modelName ||
|
||||
model.name !== current.name
|
||||
) {
|
||||
llmManager.updateCurrentLlm({
|
||||
name: model.name,
|
||||
provider: model.provider,
|
||||
modelName: model.modelName,
|
||||
});
|
||||
}
|
||||
llmManager.updateCurrentLlm({
|
||||
name: model.name,
|
||||
provider: model.provider,
|
||||
modelName: model.modelName,
|
||||
});
|
||||
}
|
||||
}, [multiModel.selectedModels]);
|
||||
|
||||
@@ -546,9 +532,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
|
||||
const onChat = useCallback(
|
||||
(message: string) => {
|
||||
if (multiModel.isMultiModelActive) {
|
||||
foldSidebarForMultiModel();
|
||||
}
|
||||
resetInputBar();
|
||||
onSubmit({
|
||||
message,
|
||||
@@ -569,7 +552,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
deepResearchEnabledForCurrentWorkflow,
|
||||
multiModel.isMultiModelActive,
|
||||
multiModel.selectedModels,
|
||||
foldSidebarForMultiModel,
|
||||
showOnboarding,
|
||||
onboardingDismissed,
|
||||
finishOnboarding,
|
||||
@@ -882,20 +864,13 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
agent={liveAgent}
|
||||
isDefaultAgent={isDefaultAgent}
|
||||
/>
|
||||
{!isSearch &&
|
||||
!(
|
||||
state.phase === "idle" && state.appMode === "search"
|
||||
) &&
|
||||
liveAgent &&
|
||||
!llmManager.isLoadingProviders && (
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
)}
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
</Section>
|
||||
<Spacer rem={1.5} />
|
||||
</Fade>
|
||||
@@ -961,19 +936,17 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
isSearch ? "h-[14px]" : "h-0"
|
||||
)}
|
||||
/>
|
||||
{appFocus.isChat() &&
|
||||
liveAgent &&
|
||||
!llmManager.isLoadingProviders && (
|
||||
<div className="pb-1">
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{appFocus.isChat() && (
|
||||
<div className="pb-1">
|
||||
<ModelSelector
|
||||
llmManager={llmManager}
|
||||
selectedModels={multiModel.selectedModels}
|
||||
onAdd={multiModel.addModel}
|
||||
onRemove={multiModel.removeModel}
|
||||
onReplace={multiModel.replaceModel}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<AppInputBar
|
||||
ref={chatInputBarRef}
|
||||
deepResearchEnabled={
|
||||
|
||||
@@ -15,7 +15,11 @@ import { SvgArrowExchange, SvgSettings, SvgTrash } from "@opal/icons";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import * as GeneralLayouts from "@/layouts/general-layouts";
|
||||
import { getProvider } from "@/lib/llmConfig";
|
||||
import {
|
||||
getProviderDisplayName,
|
||||
getProviderIcon,
|
||||
getProviderProductName,
|
||||
} from "@/lib/llmConfig/providers";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { deleteLlmProvider, setDefaultLlmModel } from "@/lib/llmConfig/svc";
|
||||
import { Horizontal as HorizontalInput } from "@/layouts/input-layouts";
|
||||
@@ -29,6 +33,19 @@ import {
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import { getModalForExistingProvider } from "@/sections/modals/llmConfig/getModal";
|
||||
import OpenAIModal from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
import AnthropicModal from "@/sections/modals/llmConfig/AnthropicModal";
|
||||
import OllamaModal from "@/sections/modals/llmConfig/OllamaModal";
|
||||
import AzureModal from "@/sections/modals/llmConfig/AzureModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { markdown } from "@opal/utils";
|
||||
|
||||
@@ -55,6 +72,51 @@ const PROVIDER_DISPLAY_ORDER: string[] = [
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
];
|
||||
|
||||
const PROVIDER_MODAL_MAP: Record<
|
||||
string,
|
||||
(
|
||||
shouldMarkAsDefault: boolean,
|
||||
onOpenChange: (open: boolean) => void
|
||||
) => React.ReactNode
|
||||
> = {
|
||||
openai: (d, onOpenChange) => (
|
||||
<OpenAIModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
anthropic: (d, onOpenChange) => (
|
||||
<AnthropicModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
ollama_chat: (d, onOpenChange) => (
|
||||
<OllamaModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
azure: (d, onOpenChange) => (
|
||||
<AzureModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
bedrock: (d, onOpenChange) => (
|
||||
<BedrockModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
vertex_ai: (d, onOpenChange) => (
|
||||
<VertexAIModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
openrouter: (d, onOpenChange) => (
|
||||
<OpenRouterModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
lm_studio: (d, onOpenChange) => (
|
||||
<LMStudioModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
litellm_proxy: (d, onOpenChange) => (
|
||||
<LiteLLMProxyModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
bifrost: (d, onOpenChange) => (
|
||||
<BifrostModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
openai_compatible: (d, onOpenChange) => (
|
||||
<OpenAICompatibleModal
|
||||
shouldMarkAsDefault={d}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// ExistingProviderCard — card for configured (existing) providers
|
||||
// ============================================================================
|
||||
@@ -63,12 +125,14 @@ interface ExistingProviderCardProps {
|
||||
provider: LLMProviderView;
|
||||
isDefault: boolean;
|
||||
isLastProvider: boolean;
|
||||
defaultModelName?: string;
|
||||
}
|
||||
|
||||
function ExistingProviderCard({
|
||||
provider,
|
||||
isDefault,
|
||||
isLastProvider,
|
||||
defaultModelName,
|
||||
}: ExistingProviderCardProps) {
|
||||
const { mutate } = useSWRConfig();
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
@@ -86,14 +150,8 @@ function ExistingProviderCard({
|
||||
}
|
||||
};
|
||||
|
||||
const { icon, companyName, Modal } = getProvider(provider.provider, provider);
|
||||
|
||||
return (
|
||||
<>
|
||||
{isOpen && (
|
||||
<Modal existingLlmProvider={provider} onOpenChange={setIsOpen} />
|
||||
)}
|
||||
|
||||
{deleteModal.isOpen && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgTrash}
|
||||
@@ -144,9 +202,9 @@ function ExistingProviderCard({
|
||||
onClick={() => setIsOpen(true)}
|
||||
>
|
||||
<CardLayout.Header
|
||||
icon={icon}
|
||||
icon={getProviderIcon(provider.provider)}
|
||||
title={provider.name}
|
||||
description={companyName}
|
||||
description={getProviderDisplayName(provider.provider)}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
tag={isDefault ? { title: "Default", color: "blue" } : undefined}
|
||||
@@ -178,6 +236,8 @@ function ExistingProviderCard({
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
{isOpen &&
|
||||
getModalForExistingProvider(provider, setIsOpen, defaultModelName)}
|
||||
</SelectCard>
|
||||
</Hoverable.Root>
|
||||
</>
|
||||
@@ -191,11 +251,18 @@ function ExistingProviderCard({
|
||||
interface NewProviderCardProps {
|
||||
provider: WellKnownLLMProviderDescriptor;
|
||||
isFirstProvider: boolean;
|
||||
formFn: (
|
||||
shouldMarkAsDefault: boolean,
|
||||
onOpenChange: (open: boolean) => void
|
||||
) => React.ReactNode;
|
||||
}
|
||||
|
||||
function NewProviderCard({ provider, isFirstProvider }: NewProviderCardProps) {
|
||||
function NewProviderCard({
|
||||
provider,
|
||||
isFirstProvider,
|
||||
formFn,
|
||||
}: NewProviderCardProps) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const { icon, productName, companyName, Modal } = getProvider(provider.name);
|
||||
|
||||
return (
|
||||
<SelectCard
|
||||
@@ -205,9 +272,9 @@ function NewProviderCard({ provider, isFirstProvider }: NewProviderCardProps) {
|
||||
onClick={() => setIsOpen(true)}
|
||||
>
|
||||
<CardLayout.Header
|
||||
icon={icon}
|
||||
title={productName}
|
||||
description={companyName}
|
||||
icon={getProviderIcon(provider.name)}
|
||||
title={getProviderProductName(provider.name)}
|
||||
description={getProviderDisplayName(provider.name)}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
rightChildren={
|
||||
@@ -223,9 +290,7 @@ function NewProviderCard({ provider, isFirstProvider }: NewProviderCardProps) {
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
{isOpen && (
|
||||
<Modal shouldMarkAsDefault={isFirstProvider} onOpenChange={setIsOpen} />
|
||||
)}
|
||||
{isOpen && formFn(isFirstProvider, setIsOpen)}
|
||||
</SelectCard>
|
||||
);
|
||||
}
|
||||
@@ -242,7 +307,6 @@ function NewCustomProviderCard({
|
||||
isFirstProvider,
|
||||
}: NewCustomProviderCardProps) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const { icon, productName, companyName, Modal } = getProvider("custom");
|
||||
|
||||
return (
|
||||
<SelectCard
|
||||
@@ -252,9 +316,9 @@ function NewCustomProviderCard({
|
||||
onClick={() => setIsOpen(true)}
|
||||
>
|
||||
<CardLayout.Header
|
||||
icon={icon}
|
||||
title={productName}
|
||||
description={companyName}
|
||||
icon={getProviderIcon("custom")}
|
||||
title={getProviderProductName("custom")}
|
||||
description={getProviderDisplayName("custom")}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
rightChildren={
|
||||
@@ -271,7 +335,10 @@ function NewCustomProviderCard({
|
||||
}
|
||||
/>
|
||||
{isOpen && (
|
||||
<Modal shouldMarkAsDefault={isFirstProvider} onOpenChange={setIsOpen} />
|
||||
<CustomModal
|
||||
shouldMarkAsDefault={isFirstProvider}
|
||||
onOpenChange={setIsOpen}
|
||||
/>
|
||||
)}
|
||||
</SelectCard>
|
||||
);
|
||||
@@ -281,7 +348,7 @@ function NewCustomProviderCard({
|
||||
// LLMConfigurationPage — main page component
|
||||
// ============================================================================
|
||||
|
||||
export default function LLMConfigurationPage() {
|
||||
export default function LLMProviderConfigurationPage() {
|
||||
const { mutate } = useSWRConfig();
|
||||
const { llmProviders: existingLlmProviders, defaultText } =
|
||||
useAdminLLMProviders();
|
||||
@@ -402,6 +469,11 @@ export default function LLMConfigurationPage() {
|
||||
provider={provider}
|
||||
isDefault={defaultText?.provider_id === provider.id}
|
||||
isLastProvider={sortedProviders.length === 1}
|
||||
defaultModelName={
|
||||
defaultText?.provider_id === provider.id
|
||||
? defaultText.model_name
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
@@ -435,13 +507,23 @@ export default function LLMConfigurationPage() {
|
||||
(bIndex === -1 ? Infinity : bIndex)
|
||||
);
|
||||
})
|
||||
.map((provider) => (
|
||||
<NewProviderCard
|
||||
key={provider.name}
|
||||
provider={provider}
|
||||
isFirstProvider={isFirstProvider}
|
||||
/>
|
||||
))}
|
||||
.map((provider) => {
|
||||
const formFn = PROVIDER_MODAL_MAP[provider.name];
|
||||
if (!formFn) {
|
||||
toast.error(
|
||||
`No modal mapping for provider "${provider.name}".`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<NewProviderCard
|
||||
key={provider.name}
|
||||
provider={provider}
|
||||
isFirstProvider={isFirstProvider}
|
||||
formFn={formFn}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
<NewCustomProviderCard isFirstProvider={isFirstProvider} />
|
||||
</div>
|
||||
</GeneralLayouts.Section>
|
||||
@@ -352,7 +352,6 @@ const ChatScrollContainer = React.memo(
|
||||
key={sessionId}
|
||||
ref={scrollContainerRef}
|
||||
data-testid="chat-scroll-container"
|
||||
data-chat-scroll
|
||||
className={cn(
|
||||
"flex flex-col flex-1 min-h-0 overflow-y-auto overflow-x-hidden",
|
||||
hideScrollbar ? "no-scrollbar" : "default-scrollbar"
|
||||
|
||||
@@ -50,7 +50,7 @@ function BifrostModalInternals({
|
||||
const { models, error } = await fetchBifrostModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.BIFROST,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import {
|
||||
@@ -29,8 +29,9 @@ import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { Button, Card, EmptyMessageCard } from "@opal/components";
|
||||
import { SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
|
||||
import { SvgMinusCircle, SvgPlusCircle, SvgRefreshCw } from "@opal/icons";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
@@ -110,6 +111,95 @@ function ModelConfigurationItem({
|
||||
);
|
||||
}
|
||||
|
||||
interface FetchedModel {
|
||||
name: string;
|
||||
display_name: string;
|
||||
max_input_tokens: number | null;
|
||||
supports_image_input: boolean;
|
||||
}
|
||||
|
||||
function FetchModelsButton({ provider }: { provider: string }) {
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
const [isFetching, setIsFetching] = useState(false);
|
||||
const formikProps = useFormikContext<{
|
||||
api_base?: string;
|
||||
api_key?: string;
|
||||
api_version?: string;
|
||||
model_configurations: CustomModelConfiguration[];
|
||||
}>();
|
||||
|
||||
useEffect(() => {
|
||||
return () => abortRef.current?.abort();
|
||||
}, []);
|
||||
|
||||
async function handleFetch() {
|
||||
abortRef.current?.abort();
|
||||
const controller = new AbortController();
|
||||
abortRef.current = controller;
|
||||
setIsFetching(true);
|
||||
try {
|
||||
const response = await fetch("/api/admin/llm/custom/available-models", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider,
|
||||
api_base: formikProps.values.api_base || undefined,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
api_version: formikProps.values.api_version || undefined,
|
||||
}),
|
||||
signal: controller.signal,
|
||||
});
|
||||
if (!response.ok) {
|
||||
let errorMessage = "Failed to fetch models";
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorMessage;
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
}
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
const fetched: FetchedModel[] = await response.json();
|
||||
const existing = formikProps.values.model_configurations;
|
||||
const existingNames = new Set(existing.map((m) => m.name));
|
||||
const newModels: CustomModelConfiguration[] = fetched
|
||||
.filter((m) => !existingNames.has(m.name))
|
||||
.map((m) => ({
|
||||
name: m.name,
|
||||
display_name: m.display_name !== m.name ? m.display_name : "",
|
||||
max_input_tokens: m.max_input_tokens,
|
||||
supports_image_input: m.supports_image_input,
|
||||
}));
|
||||
// Replace empty placeholder rows, then merge
|
||||
const nonEmpty = existing.filter((m) => m.name.trim() !== "");
|
||||
formikProps.setFieldValue("model_configurations", [
|
||||
...nonEmpty,
|
||||
...newModels,
|
||||
]);
|
||||
toast.success(`Fetched ${fetched.length} models`);
|
||||
} catch (err) {
|
||||
if (err instanceof DOMException && err.name === "AbortError") return;
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to fetch models"
|
||||
);
|
||||
} finally {
|
||||
if (!controller.signal.aborted) {
|
||||
setIsFetching(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={isFetching ? SimpleLoader : SvgRefreshCw}
|
||||
onClick={handleFetch}
|
||||
disabled={isFetching || !provider}
|
||||
type="button"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function ModelConfigurationList() {
|
||||
const formikProps = useFormikContext<{
|
||||
model_configurations: CustomModelConfiguration[];
|
||||
@@ -222,6 +312,24 @@ function ProviderNameSelect({ disabled }: { disabled?: boolean }) {
|
||||
);
|
||||
}
|
||||
|
||||
function ModelsHeader() {
|
||||
const { values } = useFormikContext<{ provider: string }>();
|
||||
return (
|
||||
<InputLayouts.Horizontal
|
||||
title="Models"
|
||||
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
|
||||
nonInteractive
|
||||
center
|
||||
>
|
||||
{values.provider ? (
|
||||
<FetchModelsButton provider={values.provider} />
|
||||
) : (
|
||||
<div />
|
||||
)}
|
||||
</InputLayouts.Horizontal>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Custom Config Processing ─────────────────────────────────────────────────
|
||||
|
||||
function keyValueListToDict(items: KeyValue[]): Record<string, string> {
|
||||
@@ -424,13 +532,7 @@ export default function CustomModal({
|
||||
<InputLayouts.FieldSeparator />
|
||||
<Section gap={0.5}>
|
||||
<InputLayouts.FieldPadder>
|
||||
<Content
|
||||
title="Models"
|
||||
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
widthVariant="full"
|
||||
/>
|
||||
<ModelsHeader />
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<Card padding="sm">
|
||||
|
||||
@@ -52,7 +52,7 @@ function LiteLLMProxyModalInternals({
|
||||
const { models, error } = await fetchLiteLLMProxyModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.LITELLM_PROXY,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
|
||||
@@ -52,7 +52,7 @@ function OpenRouterModalInternals({
|
||||
const { models, error } = await fetchOpenRouterModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.OPENROUTER,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
|
||||
75
web/src/sections/modals/llmConfig/getModal.tsx
Normal file
75
web/src/sections/modals/llmConfig/getModal.tsx
Normal file
@@ -0,0 +1,75 @@
|
||||
import { LLMProviderName, LLMProviderView } from "@/interfaces/llm";
|
||||
import AnthropicModal from "@/sections/modals/llmConfig/AnthropicModal";
|
||||
import OpenAIModal from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
import OllamaModal from "@/sections/modals/llmConfig/OllamaModal";
|
||||
import AzureModal from "@/sections/modals/llmConfig/AzureModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
|
||||
export function getModalForExistingProvider(
|
||||
provider: LLMProviderView,
|
||||
onOpenChange?: (open: boolean) => void,
|
||||
defaultModelName?: string
|
||||
) {
|
||||
const props = {
|
||||
existingLlmProvider: provider,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
};
|
||||
|
||||
const hasCustomConfig = provider.custom_config != null;
|
||||
|
||||
switch (provider.provider) {
|
||||
// These providers don't use custom_config themselves, so a non-null
|
||||
// custom_config means the provider was created via CustomModal.
|
||||
case LLMProviderName.OPENAI:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<OpenAIModal {...props} />
|
||||
);
|
||||
case LLMProviderName.ANTHROPIC:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<AnthropicModal {...props} />
|
||||
);
|
||||
case LLMProviderName.AZURE:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<AzureModal {...props} />
|
||||
);
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<OpenRouterModal {...props} />
|
||||
);
|
||||
|
||||
// These providers legitimately store settings in custom_config,
|
||||
// so always use their dedicated modals.
|
||||
case LLMProviderName.OLLAMA_CHAT:
|
||||
return <OllamaModal {...props} />;
|
||||
case LLMProviderName.VERTEX_AI:
|
||||
return <VertexAIModal {...props} />;
|
||||
case LLMProviderName.BEDROCK:
|
||||
return <BedrockModal {...props} />;
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
return <LMStudioModal {...props} />;
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
return <LiteLLMProxyModal {...props} />;
|
||||
case LLMProviderName.BIFROST:
|
||||
return <BifrostModal {...props} />;
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return <OpenAICompatibleModal {...props} />;
|
||||
default:
|
||||
return <CustomModal {...props} />;
|
||||
}
|
||||
}
|
||||
@@ -44,7 +44,11 @@ import useUsers from "@/hooks/useUsers";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { UserRole } from "@/lib/types";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { getProvider } from "@/lib/llmConfig";
|
||||
import {
|
||||
getProviderIcon,
|
||||
getProviderDisplayName,
|
||||
getProviderProductName,
|
||||
} from "@/lib/llmConfig/providers";
|
||||
|
||||
// ─── DisplayNameField ────────────────────────────────────────────────────────
|
||||
|
||||
@@ -713,11 +717,9 @@ function ModalWrapperInner({
|
||||
? "No changes to save."
|
||||
: undefined;
|
||||
|
||||
const {
|
||||
icon: providerIcon,
|
||||
companyName: providerDisplayName,
|
||||
productName: providerProductName,
|
||||
} = getProvider(providerName);
|
||||
const providerIcon = getProviderIcon(providerName);
|
||||
const providerDisplayName = getProviderDisplayName(providerName);
|
||||
const providerProductName = getProviderProductName(providerName);
|
||||
|
||||
const title = llmProvider
|
||||
? `Configure "${llmProvider.name}"`
|
||||
|
||||
145
web/src/sections/onboarding/forms/getOnboardingForm.tsx
Normal file
145
web/src/sections/onboarding/forms/getOnboardingForm.tsx
Normal file
@@ -0,0 +1,145 @@
|
||||
import React from "react";
|
||||
import {
|
||||
WellKnownLLMProviderDescriptor,
|
||||
LLMProviderName,
|
||||
LLMProviderFormProps,
|
||||
} from "@/interfaces/llm";
|
||||
import { OnboardingActions, OnboardingState } from "@/interfaces/onboarding";
|
||||
import OpenAIModal from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
import AnthropicModal from "@/sections/modals/llmConfig/AnthropicModal";
|
||||
import OllamaModal from "@/sections/modals/llmConfig/OllamaModal";
|
||||
import AzureModal from "@/sections/modals/llmConfig/AzureModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
|
||||
// Display info for LLM provider cards - title is the product name, displayName is the company/platform
|
||||
const PROVIDER_DISPLAY_INFO: Record<
|
||||
string,
|
||||
{ title: string; displayName: string }
|
||||
> = {
|
||||
[LLMProviderName.OPENAI]: { title: "GPT", displayName: "OpenAI" },
|
||||
[LLMProviderName.ANTHROPIC]: { title: "Claude", displayName: "Anthropic" },
|
||||
[LLMProviderName.OLLAMA_CHAT]: { title: "Ollama", displayName: "Ollama" },
|
||||
[LLMProviderName.AZURE]: {
|
||||
title: "Azure OpenAI",
|
||||
displayName: "Microsoft Azure Cloud",
|
||||
},
|
||||
[LLMProviderName.BEDROCK]: {
|
||||
title: "Amazon Bedrock",
|
||||
displayName: "AWS",
|
||||
},
|
||||
[LLMProviderName.VERTEX_AI]: {
|
||||
title: "Gemini",
|
||||
displayName: "Google Cloud Vertex AI",
|
||||
},
|
||||
[LLMProviderName.OPENROUTER]: {
|
||||
title: "OpenRouter",
|
||||
displayName: "OpenRouter",
|
||||
},
|
||||
[LLMProviderName.LM_STUDIO]: {
|
||||
title: "LM Studio",
|
||||
displayName: "LM Studio",
|
||||
},
|
||||
[LLMProviderName.LITELLM_PROXY]: {
|
||||
title: "LiteLLM Proxy",
|
||||
displayName: "LiteLLM Proxy",
|
||||
},
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: {
|
||||
title: "OpenAI-Compatible",
|
||||
displayName: "OpenAI-Compatible",
|
||||
},
|
||||
};
|
||||
|
||||
export function getProviderDisplayInfo(providerName: string): {
|
||||
title: string;
|
||||
displayName: string;
|
||||
} {
|
||||
return (
|
||||
PROVIDER_DISPLAY_INFO[providerName] ?? {
|
||||
title: providerName,
|
||||
displayName: providerName,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
export interface OnboardingFormProps {
|
||||
llmDescriptor?: WellKnownLLMProviderDescriptor;
|
||||
isCustomProvider?: boolean;
|
||||
onboardingState: OnboardingState;
|
||||
onboardingActions: OnboardingActions;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
}
|
||||
|
||||
export function getOnboardingForm({
|
||||
llmDescriptor,
|
||||
isCustomProvider,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
onOpenChange,
|
||||
}: OnboardingFormProps): React.ReactNode {
|
||||
const providerName = isCustomProvider
|
||||
? "custom"
|
||||
: llmDescriptor?.name ?? "custom";
|
||||
|
||||
const sharedProps: LLMProviderFormProps = {
|
||||
variant: "onboarding" as const,
|
||||
shouldMarkAsDefault:
|
||||
(onboardingState?.data.llmProviders ?? []).length === 0,
|
||||
onboardingActions,
|
||||
onOpenChange,
|
||||
onSuccess: () => {
|
||||
onboardingActions.updateData({
|
||||
llmProviders: [
|
||||
...(onboardingState?.data.llmProviders ?? []),
|
||||
providerName,
|
||||
],
|
||||
});
|
||||
onboardingActions.setButtonActive(true);
|
||||
},
|
||||
};
|
||||
|
||||
// Handle custom provider
|
||||
if (isCustomProvider || !llmDescriptor) {
|
||||
return <CustomModal {...sharedProps} />;
|
||||
}
|
||||
|
||||
switch (llmDescriptor.name) {
|
||||
case LLMProviderName.OPENAI:
|
||||
return <OpenAIModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.ANTHROPIC:
|
||||
return <AnthropicModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OLLAMA_CHAT:
|
||||
return <OllamaModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.AZURE:
|
||||
return <AzureModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.BEDROCK:
|
||||
return <BedrockModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.VERTEX_AI:
|
||||
return <VertexAIModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return <OpenRouterModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
return <LMStudioModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
return <LiteLLMProxyModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return <OpenAICompatibleModal {...sharedProps} />;
|
||||
|
||||
default:
|
||||
return <CustomModal {...sharedProps} />;
|
||||
}
|
||||
}
|
||||
@@ -4,29 +4,35 @@ import { memo, useState, useCallback } from "react";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button } from "@opal/components";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import LLMProviderCard from "@/sections/onboarding/components/LLMProviderCard";
|
||||
import LLMProviderCard from "../components/LLMProviderCard";
|
||||
import {
|
||||
OnboardingActions,
|
||||
OnboardingState,
|
||||
OnboardingStep,
|
||||
} from "@/interfaces/onboarding";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import { getProvider } from "@/lib/llmConfig";
|
||||
getOnboardingForm,
|
||||
getProviderDisplayInfo,
|
||||
} from "../forms/getOnboardingForm";
|
||||
import { Disabled } from "@opal/core";
|
||||
import ModelIcon from "@/app/admin/configuration/llm/ModelIcon";
|
||||
import { SvgCheckCircle, SvgCpu, SvgExternalLink } from "@opal/icons";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { useLLMProviderOptions } from "@/lib/hooks/useLLMProviderOptions";
|
||||
|
||||
type LLMStepProps = {
|
||||
state: OnboardingState;
|
||||
actions: OnboardingActions;
|
||||
disabled?: boolean;
|
||||
};
|
||||
|
||||
interface SelectedProvider {
|
||||
llmDescriptor?: WellKnownLLMProviderDescriptor;
|
||||
isCustomProvider: boolean;
|
||||
}
|
||||
|
||||
function LLMProviderSkeleton() {
|
||||
const LLMProviderSkeleton = () => {
|
||||
return (
|
||||
<div className="flex justify-between h-full w-full p-1 rounded-12 border border-border-01 bg-background-neutral-01 animate-pulse">
|
||||
<div className="flex gap-1 p-1 flex-1 min-w-0">
|
||||
@@ -41,11 +47,12 @@ function LLMProviderSkeleton() {
|
||||
<div className="h-6 w-16 bg-neutral-200 rounded" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
interface StackedProviderIconsProps {
|
||||
type StackedProviderIconsProps = {
|
||||
providers: string[];
|
||||
}
|
||||
};
|
||||
|
||||
const StackedProviderIcons = ({ providers }: StackedProviderIconsProps) => {
|
||||
if (!providers || providers.length === 0) {
|
||||
return null;
|
||||
@@ -82,157 +89,133 @@ const StackedProviderIcons = ({ providers }: StackedProviderIconsProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
interface LLMStepProps {
|
||||
state: OnboardingState;
|
||||
actions: OnboardingActions;
|
||||
disabled?: boolean;
|
||||
}
|
||||
const LLMStep = memo(
|
||||
({
|
||||
state: onboardingState,
|
||||
actions: onboardingActions,
|
||||
disabled,
|
||||
}: LLMStepProps) => {
|
||||
const { llmProviderOptions, isLoading } = useLLMProviderOptions();
|
||||
const llmDescriptors = llmProviderOptions ?? [];
|
||||
const LLMStepInner = ({
|
||||
state: onboardingState,
|
||||
actions: onboardingActions,
|
||||
disabled,
|
||||
}: LLMStepProps) => {
|
||||
const { llmProviderOptions, isLoading } = useLLMProviderOptions();
|
||||
const llmDescriptors = llmProviderOptions ?? [];
|
||||
|
||||
const [selectedProvider, setSelectedProvider] =
|
||||
useState<SelectedProvider | null>(null);
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
const [selectedProvider, setSelectedProvider] =
|
||||
useState<SelectedProvider | null>(null);
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
|
||||
const handleProviderClick = useCallback(
|
||||
(
|
||||
llmDescriptor?: WellKnownLLMProviderDescriptor,
|
||||
isCustomProvider: boolean = false
|
||||
) => {
|
||||
setSelectedProvider({ llmDescriptor, isCustomProvider });
|
||||
setIsModalOpen(true);
|
||||
},
|
||||
[]
|
||||
);
|
||||
const handleProviderClick = useCallback(
|
||||
(
|
||||
llmDescriptor?: WellKnownLLMProviderDescriptor,
|
||||
isCustomProvider: boolean = false
|
||||
) => {
|
||||
setSelectedProvider({ llmDescriptor, isCustomProvider });
|
||||
setIsModalOpen(true);
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const handleModalClose = useCallback((open: boolean) => {
|
||||
setIsModalOpen(open);
|
||||
if (!open) {
|
||||
setSelectedProvider(null);
|
||||
}
|
||||
}, []);
|
||||
const handleModalClose = useCallback((open: boolean) => {
|
||||
setIsModalOpen(open);
|
||||
if (!open) {
|
||||
setSelectedProvider(null);
|
||||
}
|
||||
}, []);
|
||||
|
||||
if (
|
||||
onboardingState.currentStep === OnboardingStep.LlmSetup ||
|
||||
onboardingState.currentStep === OnboardingStep.Name
|
||||
) {
|
||||
const providerName = selectedProvider?.isCustomProvider
|
||||
? "custom"
|
||||
: selectedProvider?.llmDescriptor?.name ?? "custom";
|
||||
|
||||
const { Modal: ModalComponent } = getProvider(providerName);
|
||||
|
||||
const modalProps: LLMProviderFormProps = {
|
||||
variant: "onboarding" as const,
|
||||
shouldMarkAsDefault:
|
||||
(onboardingState?.data.llmProviders ?? []).length === 0,
|
||||
onboardingActions,
|
||||
onOpenChange: handleModalClose,
|
||||
onSuccess: () => {
|
||||
onboardingActions.updateData({
|
||||
llmProviders: [
|
||||
...(onboardingState?.data.llmProviders ?? []),
|
||||
providerName,
|
||||
],
|
||||
});
|
||||
onboardingActions.setButtonActive(true);
|
||||
},
|
||||
};
|
||||
|
||||
return (
|
||||
<Disabled disabled={disabled} allowClick>
|
||||
<div
|
||||
className="flex flex-col items-center justify-between w-full p-1 rounded-16 border border-border-01 bg-background-tint-00"
|
||||
aria-label="onboarding-llm-step"
|
||||
>
|
||||
<ContentAction
|
||||
icon={SvgCpu}
|
||||
title="Connect your LLM models"
|
||||
description="Onyx supports both self-hosted models and popular providers."
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="lg"
|
||||
rightChildren={
|
||||
<Button
|
||||
disabled={disabled}
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgExternalLink}
|
||||
href="/admin/configuration/llm"
|
||||
if (
|
||||
onboardingState.currentStep === OnboardingStep.LlmSetup ||
|
||||
onboardingState.currentStep === OnboardingStep.Name
|
||||
) {
|
||||
return (
|
||||
<Disabled disabled={disabled} allowClick>
|
||||
<div
|
||||
className="flex flex-col items-center justify-between w-full p-1 rounded-16 border border-border-01 bg-background-tint-00"
|
||||
aria-label="onboarding-llm-step"
|
||||
>
|
||||
<ContentAction
|
||||
icon={SvgCpu}
|
||||
title="Connect your LLM models"
|
||||
description="Onyx supports both self-hosted models and popular providers."
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="lg"
|
||||
rightChildren={
|
||||
<Button
|
||||
disabled={disabled}
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgExternalLink}
|
||||
href="/admin/configuration/llm"
|
||||
>
|
||||
View in Admin Panel
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
<Separator />
|
||||
<div className="flex flex-wrap gap-1 [&>*:last-child:nth-child(odd)]:basis-full">
|
||||
{isLoading ? (
|
||||
Array.from({ length: 8 }).map((_, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className="basis-[calc(50%-theme(spacing.1)/2)] grow"
|
||||
>
|
||||
View in Admin Panel
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
<Separator />
|
||||
<div className="flex flex-wrap gap-1 [&>*:last-child:nth-child(odd)]:basis-full">
|
||||
{isLoading ? (
|
||||
Array.from({ length: 8 }).map((_, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className="basis-[calc(50%-theme(spacing.1)/2)] grow"
|
||||
>
|
||||
<LLMProviderSkeleton />
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<>
|
||||
{/* Render the selected provider form */}
|
||||
{selectedProvider && isModalOpen && (
|
||||
<ModalComponent {...modalProps} />
|
||||
)}
|
||||
|
||||
{/* Render provider cards */}
|
||||
{llmDescriptors.map((llmDescriptor) => {
|
||||
const { productName, companyName } = getProvider(
|
||||
llmDescriptor.name
|
||||
);
|
||||
return (
|
||||
<div
|
||||
key={llmDescriptor.name}
|
||||
className="basis-[calc(50%-theme(spacing.1)/2)] grow"
|
||||
>
|
||||
<LLMProviderCard
|
||||
title={productName}
|
||||
subtitle={companyName}
|
||||
providerName={llmDescriptor.name}
|
||||
disabled={disabled}
|
||||
isConnected={onboardingState.data.llmProviders?.some(
|
||||
(provider) => provider === llmDescriptor.name
|
||||
)}
|
||||
onClick={() =>
|
||||
handleProviderClick(llmDescriptor, false)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
<LLMProviderSkeleton />
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<>
|
||||
{/* Render the selected provider form */}
|
||||
{selectedProvider &&
|
||||
isModalOpen &&
|
||||
getOnboardingForm({
|
||||
llmDescriptor: selectedProvider.llmDescriptor,
|
||||
isCustomProvider: selectedProvider.isCustomProvider,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
onOpenChange: handleModalClose,
|
||||
})}
|
||||
|
||||
{/* Custom provider card */}
|
||||
<div className="basis-[calc(50%-theme(spacing.1)/2)] grow">
|
||||
<LLMProviderCard
|
||||
title="Custom LLM Provider"
|
||||
subtitle="LiteLLM Compatible APIs"
|
||||
disabled={disabled}
|
||||
isConnected={onboardingState.data.llmProviders?.some(
|
||||
(provider) => provider === "custom"
|
||||
)}
|
||||
onClick={() => handleProviderClick(undefined, true)}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Disabled>
|
||||
);
|
||||
}
|
||||
{/* Render provider cards */}
|
||||
{llmDescriptors.map((llmDescriptor) => {
|
||||
const displayInfo = getProviderDisplayInfo(
|
||||
llmDescriptor.name
|
||||
);
|
||||
return (
|
||||
<div
|
||||
key={llmDescriptor.name}
|
||||
className="basis-[calc(50%-theme(spacing.1)/2)] grow"
|
||||
>
|
||||
<LLMProviderCard
|
||||
title={displayInfo.title}
|
||||
subtitle={displayInfo.displayName}
|
||||
providerName={llmDescriptor.name}
|
||||
disabled={disabled}
|
||||
isConnected={onboardingState.data.llmProviders?.some(
|
||||
(provider) => provider === llmDescriptor.name
|
||||
)}
|
||||
onClick={() =>
|
||||
handleProviderClick(llmDescriptor, false)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Custom provider card */}
|
||||
<div className="basis-[calc(50%-theme(spacing.1)/2)] grow">
|
||||
<LLMProviderCard
|
||||
title="Custom LLM Provider"
|
||||
subtitle="LiteLLM Compatible APIs"
|
||||
disabled={disabled}
|
||||
isConnected={onboardingState.data.llmProviders?.some(
|
||||
(provider) => provider === "custom"
|
||||
)}
|
||||
onClick={() => handleProviderClick(undefined, true)}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Disabled>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
@@ -261,7 +244,7 @@ const LLMStep = memo(
|
||||
</button>
|
||||
);
|
||||
}
|
||||
);
|
||||
LLMStep.displayName = "LLMStep";
|
||||
};
|
||||
|
||||
const LLMStep = memo(LLMStepInner);
|
||||
export default LLMStep;
|
||||
|
||||
Reference in New Issue
Block a user