Compare commits

...

17 Commits

Author SHA1 Message Date
Nikolas Garza
cf19d0df4f feat(helm): add Prometheus metrics ports and Services for celery workers (#9630) 2026-03-27 08:03:48 +00:00
Danelegend
86a6a4c134 refactor(indexing): Vespa & Opensearch index function use Iterable (#9384) 2026-03-27 04:36:59 +00:00
SubashMohan
146b5449d2 feat: configurable file upload size and token limits via admin settings (#9232) 2026-03-27 04:23:16 +00:00
Jamison Lahman
b66991b5c5 chore(devtools): ods trace (#9688)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-27 03:56:38 +00:00
dependabot[bot]
9cb76dc027 chore(deps-dev): bump picomatch from 2.3.1 to 2.3.2 in /web (#9691)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 02:22:22 +00:00
dependabot[bot]
f66891d19e chore(deps-dev): bump handlebars from 4.7.8 to 4.7.9 in /web (#9689)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-27 01:41:29 +00:00
Nikolas Garza
c07c952ad5 chore(greptile): add nginx routing rule for non-api backend routes (#9687) 2026-03-27 00:34:15 +00:00
Nikolas Garza
be7f40a28a fix(nginx): route /scim/* to api_server (#9686) 2026-03-26 17:21:57 -07:00
Evan Lohn
26f941b5da perf: perm sync start time (#9685) 2026-03-27 00:07:53 +00:00
Jamison Lahman
b9e84c42a8 feat(providers): allow deleting all types of providers (#9625)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-26 15:20:56 -07:00
Bo-Onyx
0a1df52c2f feat(hook): Hook Form Modal Polish. (#9683) 2026-03-26 22:12:33 +00:00
Nikolas Garza
306b0d452f fix(billing): retry claimLicense up to 3x after Stripe checkout return (#9669) 2026-03-26 21:06:19 +00:00
Justin Tahara
5fdb34ba8e feat(llm): add Bifrost gateway frontend modal and provider registration (#9617) 2026-03-26 20:50:25 +00:00
Jamison Lahman
2d066631e3 fix(voice): dont soft-delete providers (#9679) 2026-03-26 19:26:32 +00:00
Evan Lohn
5c84f6c61b fix(jira): large batches fail json decode (#9677) 2026-03-26 18:53:37 +00:00
Nikolas Garza
899179d4b6 fix(api-key): clarify upgrade message for trial accounts (#9678) 2026-03-26 18:32:41 +00:00
Bo-Onyx
80d6bafc74 feat(hook): Hook connect/manage modal (#9645) 2026-03-26 18:16:33 +00:00
88 changed files with 5956 additions and 708 deletions

View File

@@ -24,6 +24,16 @@ When hardcoding a boolean variable to a constant value, remove the variable enti
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
## Nginx Routing — New Backend Routes
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)
Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.
## Full vs Lite Deployments
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.

View File

@@ -0,0 +1,35 @@
"""remove voice_provider deleted column
Revision ID: 1d78c0ca7853
Revises: a3f8b2c1d4e5
Create Date: 2026-03-26 11:30:53.883127
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "1d78c0ca7853"
down_revision = "a3f8b2c1d4e5"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Drop the `deleted` soft-delete flag column from voice_provider.

    Rows already marked deleted are purged first so that no logically
    deleted provider becomes visible again once the flag column is gone.
    """
    # Hard-delete any soft-deleted rows before dropping the column
    op.execute("DELETE FROM voice_provider WHERE deleted = true")
    op.drop_column("voice_provider", "deleted")
def downgrade() -> None:
    """Re-create the `deleted` soft-delete flag column on voice_provider.

    All surviving rows receive `deleted = false` via the server default.
    Rows hard-deleted by upgrade() cannot be restored.
    """
    restored_column = sa.Column(
        "deleted",
        sa.Boolean(),
        nullable=False,
        server_default=sa.text("false"),
    )
    op.add_column("voice_provider", restored_column)

View File

@@ -473,6 +473,8 @@ def connector_permission_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_connector=True,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(

View File

@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import HierarchyNode
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call
@@ -105,9 +106,11 @@ def _get_slack_document_access(
slack_connector: SlackConnector,
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
callback: IndexingHeartbeatInterface | None,
indexing_start: SecondsSinceUnixEpoch | None = None,
) -> Generator[DocExternalAccess, None, None]:
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
callback=callback
callback=callback,
start=indexing_start,
)
for doc_metadata_batch in slim_doc_generator:
@@ -180,9 +183,15 @@ def slack_doc_sync(
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.set_credentials_provider(provider)
indexing_start_ts: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
yield from _get_slack_document_access(
slack_connector,
slack_connector=slack_connector,
channel_permissions=channel_permissions,
callback=callback,
indexing_start=indexing_start_ts,
)

View File

@@ -6,6 +6,7 @@ from onyx.access.models import ElementExternalAccess
from onyx.access.models import ExternalAccess
from onyx.access.models import NodeExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
@@ -40,10 +41,19 @@ def generic_doc_sync(
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
indexing_start: SecondsSinceUnixEpoch | None = (
cc_pair.connector.indexing_start.timestamp()
if cc_pair.connector.indexing_start is not None
else None
)
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
start=indexing_start,
callback=callback,
):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:

View File

@@ -44,6 +44,31 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
# User Facing Features Configs
#####
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
def _read_upload_size_mb(env_var: str, default: int) -> int:
    """Read a non-negative upload-size limit (in MB) from the environment.

    Negative values are rejected with a warning and replaced by `default`.
    A non-integer value still raises ValueError, preserving the previous
    fail-fast behavior on malformed configuration.
    """
    value = int(os.environ.get(env_var, str(default)))
    if value < 0:
        # Lazy %-formatting renders the same message as the previous
        # hand-written warnings, e.g. "FOO=-5 is negative; falling back to 250".
        logger.warning(
            "%s=%d is negative; falling back to %d", env_var, value, default
        )
        value = default
    return value


# Hard ceiling for the admin-configurable file upload size (in MB).
# Self-hosted customers can raise or lower this via the environment variable.
MAX_ALLOWED_UPLOAD_SIZE_MB = _read_upload_size_mb("MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
# Default fallback for the per-user file upload size limit (in MB) when no
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
# runtime so this never silently exceeds the hard ceiling.
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _read_upload_size_mb(
    "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100
)
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
) # 1 day
@@ -61,17 +86,6 @@ CACHE_BACKEND = CacheBackendType(
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
)
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
# Defaults to 100k tokens (or 10M when vector DB is disabled).
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
FILE_TOKEN_COUNT_THRESHOLD = int(
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
)
# Maximum upload size for a single user file (chat/projects) in MB.
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
# If set to true, will show extra/uncommon connectors in the "Other" category
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"

View File

@@ -890,8 +890,8 @@ class ConfluenceConnector(
def _retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
@@ -915,8 +915,8 @@ class ConfluenceConnector(
self.confluence_client, doc_id, restrictions, ancestors
) or space_level_access_info.get(page_space_key)
# Query pages
page_query = self.base_cql_page_query + self.cql_label_filter
# Query pages (with optional time filtering for indexing_start)
page_query = self._construct_page_cql_query(start, end)
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
@@ -950,7 +950,9 @@ class ConfluenceConnector(
# Query attachments for each page
page_hierarchy_node_yielded = False
attachment_query = self._construct_attachment_query(_get_page_id(page))
attachment_query = self._construct_attachment_query(
_get_page_id(page), start, end
)
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_query,
expand=restrictions_expand,

View File

@@ -10,6 +10,7 @@ from datetime import timedelta
from datetime import timezone
from typing import Any
import requests
from jira import JIRA
from jira.exceptions import JIRAError
from jira.resources import Issue
@@ -239,29 +240,53 @@ def enhanced_search_ids(
)
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO: move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
def _bulk_fetch_request(
jira_client: JIRA, issue_ids: list[str], fields: str | None
) -> list[dict[str, Any]]:
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
# Prepare the payload according to Jira API v3 specification
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
# Only restrict fields if specified, might want to explicitly do this in the future
# to avoid reading unnecessary data
payload["fields"] = fields.split(",") if fields else ["*all"]
resp = jira_client._session.post(bulk_fetch_path, json=payload)
return resp.json()["issues"]
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
try:
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
except requests.exceptions.JSONDecodeError:
if len(issue_ids) <= 1:
logger.exception(
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
f"be decoded as JSON (response too large or truncated)."
)
raise
mid = len(issue_ids) // 2
logger.warning(
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
)
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
return left + right
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise e
raise
return [
Issue(jira_client._options, jira_client._session, raw=issue)
for issue in response["issues"]
for issue in raw_issues
]

View File

@@ -1765,7 +1765,11 @@ class SharepointConnector(
checkpoint.current_drive_delta_next_link = None
checkpoint.seen_document_ids.clear()
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
def _fetch_slim_documents_from_sharepoint(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateSlimDocumentOutput:
site_descriptors = self._filter_excluded_sites(
self.site_descriptors or self.fetch_sites()
)
@@ -1786,7 +1790,9 @@ class SharepointConnector(
# Process site documents if flag is True
if self.include_site_documents:
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
site_descriptor=site_descriptor
site_descriptor=site_descriptor,
start=start,
end=end,
):
if self._is_driveitem_excluded(driveitem):
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
@@ -1841,7 +1847,9 @@ class SharepointConnector(
# Process site pages if flag is True
if self.include_site_pages:
site_pages = self._fetch_site_pages(site_descriptor)
site_pages = self._fetch_site_pages(
site_descriptor, start=start, end=end
)
for site_page in site_pages:
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
@@ -2565,12 +2573,22 @@ class SharepointConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
) -> GenerateSlimDocumentOutput:
yield from self._fetch_slim_documents_from_sharepoint()
start_dt = (
datetime.fromtimestamp(start, tz=timezone.utc)
if start is not None
else None
)
end_dt = (
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
)
yield from self._fetch_slim_documents_from_sharepoint(
start=start_dt,
end=end_dt,
)
if __name__ == "__main__":

View File

@@ -516,6 +516,8 @@ def _get_all_doc_ids(
] = default_msg_filter,
callback: IndexingHeartbeatInterface | None = None,
workspace_url: str | None = None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
"""
Get all document ids in the workspace, channel by channel
@@ -546,6 +548,8 @@ def _get_all_doc_ids(
client=client,
channel=channel,
callback=callback,
oldest=str(start) if start else None, # 0.0 -> None intentionally
latest=str(end) if end is not None else None,
)
for message_batch in channel_message_batches:
@@ -847,8 +851,8 @@ class SlackConnector(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
@@ -861,6 +865,8 @@ class SlackConnector(
msg_filter_func=self.msg_filter_func,
callback=callback,
workspace_url=self._workspace_url,
start=start,
end=end,
)
def _load_from_checkpoint(

View File

@@ -3135,8 +3135,6 @@ class VoiceProvider(Base):
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

View File

@@ -17,39 +17,30 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
"""Fetch all voice providers."""
return list(
db_session.scalars(
select(VoiceProvider)
.where(VoiceProvider.deleted.is_(False))
.order_by(VoiceProvider.name)
).all()
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
)
def fetch_voice_provider_by_id(
db_session: Session, provider_id: int, include_deleted: bool = False
db_session: Session, provider_id: int
) -> VoiceProvider | None:
"""Fetch a voice provider by ID."""
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
if not include_deleted:
stmt = stmt.where(VoiceProvider.deleted.is_(False))
return db_session.scalar(stmt)
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.id == provider_id)
)
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default STT provider."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.is_default_stt.is_(True))
.where(VoiceProvider.deleted.is_(False))
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
)
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default TTS provider."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.is_default_tts.is_(True))
.where(VoiceProvider.deleted.is_(False))
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
)
@@ -58,9 +49,7 @@ def fetch_voice_provider_by_type(
) -> VoiceProvider | None:
"""Fetch a voice provider by type."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.provider_type == provider_type)
.where(VoiceProvider.deleted.is_(False))
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
)
@@ -119,10 +108,10 @@ def upsert_voice_provider(
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
"""Soft-delete a voice provider by ID."""
"""Delete a voice provider by ID."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider:
provider.deleted = True
db_session.delete(provider)
db_session.flush()

View File

@@ -5,6 +5,7 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
out against a nonexistent Vespa/OpenSearch instance.
"""
from collections.abc import Iterable
from typing import Any
from onyx.context.search.models import IndexFilters
@@ -66,7 +67,7 @@ class DisabledDocumentIndex(DocumentIndex):
# ------------------------------------------------------------------
def index(
self,
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
index_batch_params: IndexBatchParams, # noqa: ARG002
) -> set[DocumentInsertionRecord]:
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)

View File

@@ -1,4 +1,5 @@
import abc
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@@ -206,7 +207,7 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[DocumentInsertionRecord]:
"""
@@ -226,8 +227,8 @@ class Indexable(abc.ABC):
it is done automatically outside of this code.
Parameters:
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- chunks: Document chunks with all of the information needed for
indexing to the document index.
- tenant_id: The tenant id of the user whose chunks are being indexed
- large_chunks_enabled: Whether large chunks are enabled

View File

@@ -1,4 +1,5 @@
import abc
from collections.abc import Iterable
from typing import Self
from pydantic import BaseModel
@@ -209,10 +210,10 @@ class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes a list of document chunks into the document index.
"""Indexes an iterable of document chunks into the document index.
This is often a batch operation including chunks from multiple
documents.

View File

@@ -1,5 +1,5 @@
import json
from collections import defaultdict
from collections.abc import Iterable
from typing import Any
import httpx
@@ -351,7 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
@@ -647,10 +647,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata, # noqa: ARG002
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
"""Indexes a list of document chunks into the document index.
"""Indexes an iterable of document chunks into the document index.
Groups chunks by document ID and for each document, deletes existing
chunks and indexes the new chunks in bulk.
@@ -673,29 +673,34 @@ class OpenSearchDocumentIndex(DocumentIndex):
document is newly indexed or had already existed and was just
updated.
"""
# Group chunks by document ID.
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
list
total_chunks = sum(
cc.new_chunk_cnt
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
)
for chunk in chunks:
doc_id_to_chunks[chunk.source_document.id].append(chunk)
logger.debug(
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
f"documents for index {self._index_name}."
)
document_indexing_results: list[DocumentInsertionRecord] = []
# Try to index per-document.
for _, chunks in doc_id_to_chunks.items():
deleted_doc_ids: set[str] = set()
# Buffer chunks per document as they arrive from the iterable.
# When the document ID changes flush the buffered chunks.
current_doc_id: str | None = None
current_chunks: list[DocMetadataAwareIndexChunk] = []
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
assert len(doc_chunks) > 0, "doc_chunks is empty"
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
# Do this before deleting existing chunks to reduce the amount of
# time the document index has no content for a given document, and
# to reduce the chance of entering a state where we delete chunks,
# then some error happens, and never successfully index new chunks.
# Since we are doing this in batches, an error occurring midway
# can result in a state where chunks are deleted and not all the
# new chunks have been indexed.
chunk_batch: list[DocumentChunk] = [
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
_convert_onyx_chunk_to_opensearch_document(chunk)
for chunk in doc_chunks
]
onyx_document: Document = chunks[0].source_document
onyx_document: Document = doc_chunks[0].source_document
# First delete the doc's chunks from the index. This is so that
# there are no dangling chunks in the index, in the event that the
# new document's content contains fewer chunks than the previous
@@ -704,22 +709,40 @@ class OpenSearchDocumentIndex(DocumentIndex):
# if the chunk count has actually decreased. This assumes that
# overlapping chunks are perfectly overwritten. If we can't
# guarantee that then we need the code as-is.
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
# If we see that chunks were deleted we assume the doc already
# existed.
document_insertion_record = DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
if onyx_document.id not in deleted_doc_ids:
num_chunks_deleted = self.delete(
onyx_document.id, onyx_document.chunk_count
)
deleted_doc_ids.add(onyx_document.id)
# If we see that chunks were deleted we assume the doc already
# existed. We record the result before bulk_index_documents
# runs. If indexing raises, this entire result list is discarded
# by the caller's retry logic, so early recording is safe.
document_indexing_results.append(
DocumentInsertionRecord(
document_id=onyx_document.id,
already_existed=num_chunks_deleted > 0,
)
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
document_indexing_results.append(document_insertion_record)
for chunk in chunks:
doc_id = chunk.source_document.id
if doc_id != current_doc_id:
if current_chunks:
_flush_chunks(current_chunks)
current_doc_id = doc_id
current_chunks = [chunk]
else:
current_chunks.append(chunk)
if current_chunks:
_flush_chunks(current_chunks)
return document_indexing_results

View File

@@ -6,6 +6,7 @@ import re
import time
import urllib
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
@@ -461,7 +462,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""

View File

@@ -1,6 +1,8 @@
import concurrent.futures
import logging
import random
from collections.abc import Generator
from collections.abc import Iterable
from typing import Any
from uuid import UUID
@@ -318,7 +320,7 @@ class VespaDocumentIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterable[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
@@ -338,22 +340,31 @@ class VespaDocumentIndex(DocumentIndex):
# Vespa has restrictions on valid characters, yet document IDs come from
# external w.r.t. this class. We need to sanitize them.
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
clean_chunk_id_copy(chunk) for chunk in chunks
]
assert len(cleaned_chunks) == len(
chunks
), "Bug: Cleaned chunks and input chunks have different lengths."
#
# Instead of materializing all cleaned chunks upfront, we stream them
# through a generator that cleans IDs and builds the original-ID mapping
# incrementally as chunks flow into Vespa.
def _clean_and_track(
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
id_map: dict[str, str],
seen_ids: set[str],
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
"""Cleans chunk IDs and builds the original-ID mapping
incrementally as chunks flow through, avoiding a separate
materialization pass."""
for chunk in chunks_iter:
original_id = chunk.source_document.id
cleaned = clean_chunk_id_copy(chunk)
cleaned_id = cleaned.source_document.id
# Needed so the final DocumentInsertionRecord returned can have
# the original document ID. cleaned_chunks might not contain IDs
# exactly as callers supplied them.
id_map[cleaned_id] = original_id
seen_ids.add(cleaned_id)
yield cleaned
# Needed so the final DocumentInsertionRecord returned can have the
# original document ID. cleaned_chunks might not contain IDs exactly as
# callers supplied them.
new_document_id_to_original_document_id: dict[str, str] = dict()
for i, cleaned_chunk in enumerate(cleaned_chunks):
old_chunk = chunks[i]
new_document_id_to_original_document_id[
cleaned_chunk.source_document.id
] = old_chunk.source_document.id
new_document_id_to_original_document_id: dict[str, str] = {}
all_cleaned_doc_ids: set[str] = set()
existing_docs: set[str] = set()
@@ -409,7 +420,13 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
# Insert new Vespa documents.
# Insert new Vespa documents, streaming through the cleaning
# pipeline so chunks are never fully materialized.
cleaned_chunks = _clean_and_track(
chunks,
new_document_id_to_original_document_id,
all_cleaned_doc_ids,
)
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
batch_index_vespa_chunks(
chunks=chunk_batch,
@@ -419,10 +436,6 @@ class VespaDocumentIndex(DocumentIndex):
executor=executor,
)
all_cleaned_doc_ids: set[str] = {
chunk.source_document.id for chunk in cleaned_chunks
}
return [
DocumentInsertionRecord(
document_id=new_document_id_to_original_document_id[cleaned_doc_id],

View File

@@ -29,6 +29,7 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.llm.factory import get_default_llm
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
@@ -173,8 +174,10 @@ class UserFileIndexingAdapter:
[chunk.content for chunk in user_file_chunks]
)
user_file_id_to_raw_text[str(user_file_id)] = combined_content
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
token_count: int = (
count_tokens(combined_content, llm_tokenizer)
if llm_tokenizer
else 0
)
user_file_id_to_token_count[str(user_file_id)] = token_count
else:

View File

@@ -175,6 +175,32 @@ def get_tokenizer(
return _check_tokenizer_cache(provider_type, model_name)
# Max characters per encode() call.
_ENCODE_CHUNK_SIZE = 500_000


def count_tokens(
    text: str,
    tokenizer: BaseTokenizer,
    token_limit: int | None = None,
) -> int:
    """Count tokens, chunking the input to avoid tiktoken stack overflow.

    If token_limit is provided and the text is large enough to require
    multiple chunks (> 500k chars), stops early once the count exceeds it.
    When early-exiting, the returned value exceeds token_limit but may be
    less than the true full token count.
    """
    text_len = len(text)
    # Common case: small enough to encode in a single call.
    if text_len <= _ENCODE_CHUNK_SIZE:
        return len(tokenizer.encode(text))

    running_total = 0
    offset = 0
    while offset < text_len:
        segment = text[offset : offset + _ENCODE_CHUNK_SIZE]
        running_total += len(tokenizer.encode(segment))
        if token_limit is not None and running_total > token_limit:
            # Already over the limit — the exact remainder is not needed.
            break
        offset += _ENCODE_CHUNK_SIZE
    return running_total
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:

View File

@@ -44,11 +44,12 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
# ---------------------------------------------------------------------------
@@ -141,19 +142,11 @@ def _validate_endpoint(
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
# Any timeout (connect, read, or write) means the configured timeout_seconds
# is too low for this endpoint. Report as timeout so the UI directs the user
# to increase the timeout setting.
logger.warning(
"Hook endpoint validation: read/write timeout for %s",
"Hook endpoint validation: timeout for %s",
endpoint_url,
exc_info=exc,
)

View File

@@ -9,20 +9,15 @@ from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.password_validation import is_file_password_protected
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SKIP_USERFILE_THRESHOLD
from shared_configs.configs import SKIP_USERFILE_THRESHOLD_TENANT_LIST
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -161,8 +156,8 @@ def categorize_uploaded_files(
document formats (.pdf, .docx, …) and falls back to a text-detection
heuristic for unknown extensions (.py, .js, .rs, …).
- Uses default tokenizer to compute token length.
- If token length > threshold, reject file (unless threshold skip is enabled).
- If text cannot be extracted, reject file.
- If token length exceeds the admin-configured threshold, reject file.
- If extension unsupported or text cannot be extracted, reject file.
- Otherwise marked as acceptable.
"""
@@ -173,36 +168,33 @@ def categorize_uploaded_files(
provider_type = default_model.llm_provider.provider if default_model else None
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
# Check if threshold checks should be skipped
skip_threshold = False
# Check global skip flag (works for both single-tenant and multi-tenant)
if SKIP_USERFILE_THRESHOLD:
skip_threshold = True
logger.info("Skipping userfile threshold check (global setting)")
# Check tenant-specific skip list (only applicable in multi-tenant)
elif MULTI_TENANT and SKIP_USERFILE_THRESHOLD_TENANT_LIST:
try:
current_tenant_id = get_current_tenant_id()
skip_threshold = current_tenant_id in SKIP_USERFILE_THRESHOLD_TENANT_LIST
if skip_threshold:
logger.info(
f"Skipping userfile threshold check for tenant: {current_tenant_id}"
)
except RuntimeError as e:
logger.warning(f"Failed to get current tenant ID: {str(e)}")
# Derive limits from admin-configurable settings.
# For upload size: load_settings() resolves 0/None to a positive default.
# For token threshold: 0 means "no limit" (converted to None below).
settings = load_settings()
max_upload_size_mb = (
settings.user_file_max_upload_size_mb
) # always positive after load_settings()
max_upload_size_bytes = (
max_upload_size_mb * 1024 * 1024 if max_upload_size_mb else None
)
token_threshold_k = settings.file_token_count_threshold_k
token_threshold = (
token_threshold_k * 1000 if token_threshold_k else None
) # 0 → None = no limit
for upload in files:
try:
filename = get_safe_filename(upload)
# Size limit is a hard safety cap and is enforced even when token
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
# Size limit is a hard safety cap.
if max_upload_size_bytes is not None and is_upload_too_large(
upload, max_upload_size_bytes
):
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
reason=f"Exceeds {max_upload_size_mb} MB file size limit",
)
)
continue
@@ -224,11 +216,11 @@ def categorize_uploaded_files(
)
continue
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
if token_threshold is not None and token_count > token_threshold:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
reason=f"Exceeds {token_threshold_k}K token limit",
)
)
else:
@@ -269,12 +261,14 @@ def categorize_uploaded_files(
)
continue
token_count = len(tokenizer.encode(text_content))
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
token_count = count_tokens(
text_content, tokenizer, token_limit=token_threshold
)
if token_threshold is not None and token_count > token_threshold:
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
reason=f"Exceeds {token_threshold_k}K token limit",
)
)
else:

View File

@@ -1524,6 +1524,7 @@ def get_bifrost_available_models(
display_name=model_name,
max_input_tokens=model.get("context_length"),
supports_image_input=infer_vision_support(model_id),
supports_reasoning=is_reasoning_model(model_id, model_name),
)
)
except Exception as e:

View File

@@ -463,3 +463,4 @@ class BifrostFinalModelResponse(BaseModel):
display_name: str # Human-readable name from Bifrost API
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool

View File

@@ -9,7 +9,9 @@ from onyx import __version__ as onyx_version
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import is_user_admin
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.constants import NotificationType
from onyx.db.engine.sql_engine import get_session
@@ -17,10 +19,16 @@ from onyx.db.models import User
from onyx.db.notification import dismiss_all_notifications
from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Notification
from onyx.server.settings.models import Settings
from onyx.server.settings.models import UserSettings
@@ -41,6 +49,15 @@ basic_router = APIRouter(prefix="/settings")
def admin_put_settings(
settings: Settings, _: User = Depends(current_admin_user)
) -> None:
if (
settings.user_file_max_upload_size_mb is not None
and settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"File upload size limit cannot exceed {MAX_ALLOWED_UPLOAD_SIZE_MB} MB",
)
store_settings(settings)
@@ -83,6 +100,16 @@ def fetch_settings(
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=HOOKS_AVAILABLE,
version=onyx_version,
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
default_user_file_max_upload_size_mb=min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
),
default_file_token_count_threshold_k=(
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
),
)

View File

@@ -2,12 +2,19 @@ from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import Field
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.constants import NotificationType
from onyx.configs.constants import QueryHistoryType
from onyx.db.models import Notification as NotificationDBModel
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB = 200
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB = 10000
class PageType(str, Enum):
CHAT = "chat"
@@ -78,7 +85,12 @@ class Settings(BaseModel):
# User Knowledge settings
user_knowledge_enabled: bool | None = True
user_file_max_upload_size_mb: int | None = None
user_file_max_upload_size_mb: int | None = Field(
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
)
file_token_count_threshold_k: int | None = Field(
default=None, ge=0 # thousands of tokens; None = context-aware default
)
# Connector settings
show_extra_connectors: bool | None = True
@@ -108,3 +120,14 @@ class UserSettings(Settings):
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None
# Hard ceiling for user_file_max_upload_size_mb, derived from env var.
max_allowed_upload_size_mb: int = MAX_ALLOWED_UPLOAD_SIZE_MB
# Factory defaults so the frontend can show a "restore default" button.
default_user_file_max_upload_size_mb: int = DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
default_file_token_count_threshold_k: int = Field(
default_factory=lambda: (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
)

View File

@@ -1,13 +1,19 @@
from onyx.cache.factory import get_cache_backend
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -51,9 +57,36 @@ def load_settings() -> Settings:
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
# Resolve context-aware defaults for token threshold.
# None = admin hasn't set a value yet → use context-aware default.
# 0 = admin explicitly chose "no limit" → preserve as-is.
if settings.file_token_count_threshold_k is None:
settings.file_token_count_threshold_k = (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
if DISABLE_VECTOR_DB
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
# Upload size: 0 and None are treated as "unset" (not "no limit") →
# fall back to min(configured default, hard ceiling).
if not settings.user_file_max_upload_size_mb:
settings.user_file_max_upload_size_mb = min(
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
MAX_ALLOWED_UPLOAD_SIZE_MB,
)
# Clamp to env ceiling so stale KV values are capped even if the
# operator lowered MAX_ALLOWED_UPLOAD_SIZE_MB after a higher value
# was already saved (api.py only guards new writes).
if (
settings.user_file_max_upload_size_mb > 0
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
):
settings.user_file_max_upload_size_mb = MAX_ALLOWED_UPLOAD_SIZE_MB
return settings

View File

@@ -191,25 +191,6 @@ IGNORED_SYNCING_TENANT_LIST = (
else None
)
# Global flag to skip userfile threshold for all users/tenants
SKIP_USERFILE_THRESHOLD = (
os.environ.get("SKIP_USERFILE_THRESHOLD", "").lower() == "true"
)
# Comma-separated list of specific tenant IDs to skip threshold (multi-tenant only)
SKIP_USERFILE_THRESHOLD_TENANT_IDS = os.environ.get(
"SKIP_USERFILE_THRESHOLD_TENANT_IDS"
)
SKIP_USERFILE_THRESHOLD_TENANT_LIST = (
[
tenant.strip()
for tenant in SKIP_USERFILE_THRESHOLD_TENANT_IDS.split(",")
if tenant.strip()
]
if SKIP_USERFILE_THRESHOLD_TENANT_IDS
else None
)
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"

View File

@@ -1,4 +1,6 @@
import time
from datetime import datetime
from datetime import timezone
import pytest
@@ -17,6 +19,10 @@ PRIVATE_CHANNEL_USERS = [
"test_user_2@onyx-test.com",
]
# Predates any test workspace messages, so the result set should match
# the "no start time" case while exercising the oldest= parameter.
OLDEST_TS_2016 = datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp()
pytestmark = pytest.mark.usefixtures("enable_ee")
@@ -105,15 +111,17 @@ def test_load_from_checkpoint_access__private_channel(
],
indirect=True,
)
@pytest.mark.parametrize("start_ts", [None, OLDEST_TS_2016])
def test_slim_documents_access__public_channel(
slack_connector: SlackConnector,
start_ts: float | None,
) -> None:
"""Test that retrieve_all_slim_docs_perm_sync returns correct access information for slim documents."""
if not slack_connector.client:
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=0.0,
start=start_ts,
end=time.time(),
)
@@ -149,7 +157,7 @@ def test_slim_documents_access__private_channel(
raise RuntimeError("Web client must be defined")
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
start=0.0,
start=None,
end=time.time(),
)

View File

@@ -103,6 +103,11 @@ _EXPECTED_CONFLUENCE_GROUPS = [
user_emails={"oauth@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="no yuhong allowed",
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
gives_anyone_access=False,
),
]

View File

@@ -6,6 +6,7 @@ These tests assume Vespa and OpenSearch are running.
import time
import uuid
from collections.abc import Generator
from collections.abc import Iterator
import httpx
import pytest
@@ -21,6 +22,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import DocMetadataAwareIndexChunk
from tests.external_dependency_unit.constants import TEST_TENANT_ID
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
from tests.external_dependency_unit.document_index.conftest import make_chunk
@@ -201,3 +203,25 @@ class TestDocumentIndexNew:
assert len(result_map) == 2
assert result_map[existing_doc] is True
assert result_map[new_doc] is False
def test_index_accepts_generator(
    self,
    document_indices: list[DocumentIndexNew],
    tenant_context: None,  # noqa: ARG002
) -> None:
    """index() accepts a generator (any iterable), not just a list."""
    for document_index in document_indices:
        # Unique doc id per run/backend so repeated runs don't collide
        # with leftover indexed data.
        doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
        metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])

        # Lazily yield chunks so index() receives a one-shot generator
        # rather than a materialized list.
        def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
            for i in range(3):
                yield make_chunk(doc_id, chunk_id=i)

        results = document_index.index(
            chunks=chunk_gen(), indexing_metadata=metadata
        )
        # Exactly one brand-new document record should be reported back.
        assert len(results) == 1
        assert results[0].document_id == doc_id
        assert results[0].already_existed is False

View File

@@ -5,6 +5,7 @@ These tests assume Vespa and OpenSearch are running.
import time
from collections.abc import Generator
from collections.abc import Iterator
import pytest
@@ -166,3 +167,29 @@ class TestDocumentIndexOld:
batch_retrieval=True,
)
assert len(inference_chunks) == 0
def test_index_accepts_generator(
    self,
    document_indices: list[DocumentIndex],
    tenant_context: None,  # noqa: ARG002
) -> None:
    """index() accepts a generator (any iterable), not just a list."""
    for document_index in document_indices:
        # Lazily yield chunks so index() receives a one-shot generator
        # rather than a materialized list.
        def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
            for i in range(3):
                yield make_chunk("test_doc_gen", chunk_id=i)

        # Chunk counts tell the index this is a brand-new doc (0 previous,
        # 3 new chunks).
        index_batch_params = IndexBatchParams(
            doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
            doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
            tenant_id=get_current_tenant_id(),
            large_chunks_enabled=False,
        )
        results = document_index.index(chunk_gen(), index_batch_params)
        # Exactly one new document record should come back.
        assert len(results) == 1
        record = results.pop()
        assert record.document_id == "test_doc_gen"
        assert record.already_existed is False

View File

@@ -0,0 +1,147 @@
from typing import Any
from unittest.mock import MagicMock
import pytest
import requests
from jira import JIRA
from jira.resources import Issue
from onyx.connectors.jira.connector import bulk_fetch_issues
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
return {
"id": issue_id,
"key": f"TEST-{issue_id}",
"fields": {"summary": f"Issue {issue_id}"},
}
def _mock_jira_client() -> MagicMock:
    """Create a MagicMock Jira client exposing only the private attributes
    that bulk_fetch_issues touches (_options, _session, _get_url)."""
    client = MagicMock(spec=JIRA)
    client._session = MagicMock()
    client._get_url = MagicMock(
        return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
    )
    client._options = {"server": "https://jira.example.com"}
    return client
def test_bulk_fetch_success() -> None:
    """Happy path: all issues fetched in one request."""
    mock_client = _mock_jira_client()
    raw_issues = [_make_raw_issue(i) for i in ("1", "2", "3")]
    response = MagicMock()
    response.json.return_value = {"issues": raw_issues}
    mock_client._session.post.return_value = response

    fetched = bulk_fetch_issues(mock_client, ["1", "2", "3"])

    # Every raw payload comes back wrapped as a jira Issue resource.
    assert len(fetched) == 3
    for issue in fetched:
        assert isinstance(issue, Issue)
    # A single batch means a single POST.
    mock_client._session.post.assert_called_once()
def test_bulk_fetch_splits_on_json_error() -> None:
    """When the full batch fails with JSONDecodeError, sub-batches succeed."""
    client = _mock_jira_client()
    call_count = 0

    def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock:  # noqa: ARG001
        nonlocal call_count
        call_count += 1
        ids = json["issueIdsOrKeys"]
        if len(ids) > 2:
            # Batches larger than 2 return undecodable JSON, forcing
            # bulk_fetch_issues to split the batch and retry.
            resp = MagicMock()
            resp.json.side_effect = requests.exceptions.JSONDecodeError(
                "Expecting ',' delimiter", "doc", 2294125
            )
            return resp
        # Small batches succeed and echo back the requested ids.
        resp = MagicMock()
        resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
        return resp

    client._session.post.side_effect = _post_side_effect
    result = bulk_fetch_issues(client, ["1", "2", "3", "4"])

    # All four issues are eventually fetched via smaller sub-batches.
    assert len(result) == 4
    returned_ids = {r.raw["id"] for r in result}
    assert returned_ids == {"1", "2", "3", "4"}
    # More than one POST proves the batch was actually split.
    assert call_count > 1
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
    """A single issue that always fails JSON decode raises after splitting."""
    client = _mock_jira_client()

    def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock:  # noqa: ARG001
        ids = json["issueIdsOrKeys"]
        if "bad" in ids:
            # Any batch containing "bad" yields undecodable JSON, so
            # splitting can never isolate a successful fetch for it.
            resp = MagicMock()
            resp.json.side_effect = requests.exceptions.JSONDecodeError(
                "Expecting ',' delimiter", "doc", 100
            )
            return resp
        resp = MagicMock()
        resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
        return resp

    client._session.post.side_effect = _post_side_effect
    # Once splitting reduces the batch to just "bad", the decode error
    # must surface to the caller rather than being swallowed.
    with pytest.raises(requests.exceptions.JSONDecodeError):
        bulk_fetch_issues(client, ["1", "bad", "2"])
def test_bulk_fetch_non_json_error_propagates() -> None:
    """Non-JSONDecodeError exceptions still propagate."""
    client = _mock_jira_client()
    resp = MagicMock()
    resp.json.side_effect = ValueError("something else broke")
    client._session.post.return_value = resp
    # Use pytest.raises — consistent with the other tests in this module —
    # instead of try/except + `assert False`, which is silently stripped
    # when Python runs with the -O flag.
    with pytest.raises(ValueError):
        bulk_fetch_issues(client, ["1"])
def test_bulk_fetch_with_fields() -> None:
    """Fields parameter is forwarded correctly."""
    mock_client = _mock_jira_client()
    response = MagicMock()
    response.json.return_value = {"issues": [_make_raw_issue("1")]}
    mock_client._session.post.return_value = response

    bulk_fetch_issues(mock_client, ["1"], fields="summary,description")

    # The comma-separated string must arrive in the POST body as a list.
    payload = mock_client._session.post.call_args[1]["json"]
    assert payload["fields"] == ["summary", "description"]
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
    """With a 6-issue batch where one is bad, recursion isolates it and raises."""
    client = _mock_jira_client()
    bad_id = "BAD"

    def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock:  # noqa: ARG001
        ids = json["issueIdsOrKeys"]
        if bad_id in ids:
            # Every batch containing BAD fails to decode, forcing repeated
            # splitting until BAD stands alone — at which point the error
            # must be raised rather than retried forever.
            resp = MagicMock()
            resp.json.side_effect = requests.exceptions.JSONDecodeError(
                "truncated", "doc", 999
            )
            return resp
        resp = MagicMock()
        resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
        return resp

    client._session.post.side_effect = _post_side_effect
    with pytest.raises(requests.exceptions.JSONDecodeError):
        bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -31,6 +33,7 @@ def mock_jira_cc_pair(
"jira_base_url": jira_base_url,
"project_key": project_key,
}
mock_cc_pair.connector.indexing_start = None
return mock_cc_pair
@@ -65,3 +68,75 @@ def test_jira_permission_sync(
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
):
print(doc)
def test_jira_doc_sync_passes_indexing_start(
    jira_connector: JiraConnector,
    mock_jira_cc_pair: MagicMock,
    mock_fetch_all_existing_docs_fn: MagicMock,
    mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
    """Verify that generic_doc_sync derives indexing_start from cc_pair
    and forwards it to retrieve_all_slim_docs_perm_sync."""
    # Fixed, timezone-aware start date placed on the connector model.
    indexing_start_dt = datetime(2025, 6, 1, tzinfo=timezone.utc)
    mock_jira_cc_pair.connector.indexing_start = indexing_start_dt
    with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
        mock_build_client.return_value = jira_connector._jira_client
        assert jira_connector._jira_client is not None
        # NOTE(review): _options is replaced by a callable mock returning the
        # API-version mapping — presumably matching how the connector reads
        # it; confirm against JiraConnector internals.
        jira_connector._jira_client._options = MagicMock()
        jira_connector._jira_client._options.return_value = {
            "rest_api_version": JIRA_SERVER_API_VERSION
        }
        with patch.object(
            type(jira_connector),
            "retrieve_all_slim_docs_perm_sync",
            return_value=iter([]),
        ) as mock_retrieve:
            # Drain the generator so the sync body actually executes.
            list(
                jira_doc_sync(
                    cc_pair=mock_jira_cc_pair,
                    fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
                    fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
                )
            )
            mock_retrieve.assert_called_once()
            call_kwargs = mock_retrieve.call_args
            # The datetime must be forwarded as an epoch timestamp.
            assert call_kwargs.kwargs["start"] == indexing_start_dt.timestamp()
def test_jira_doc_sync_passes_none_when_no_indexing_start(
    jira_connector: JiraConnector,
    mock_jira_cc_pair: MagicMock,
    mock_fetch_all_existing_docs_fn: MagicMock,
    mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
    """Verify that indexing_start is None when the connector has no indexing_start set."""
    # No indexing_start on the connector → sync should not invent one.
    mock_jira_cc_pair.connector.indexing_start = None
    with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
        mock_build_client.return_value = jira_connector._jira_client
        assert jira_connector._jira_client is not None
        # NOTE(review): mirrors the _options setup in the sibling test above;
        # confirm the connector reads _options() as a callable.
        jira_connector._jira_client._options = MagicMock()
        jira_connector._jira_client._options.return_value = {
            "rest_api_version": JIRA_SERVER_API_VERSION
        }
        with patch.object(
            type(jira_connector),
            "retrieve_all_slim_docs_perm_sync",
            return_value=iter([]),
        ) as mock_retrieve:
            # Drain the generator so the sync body actually executes.
            list(
                jira_doc_sync(
                    cc_pair=mock_jira_cc_pair,
                    fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
                    fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
                )
            )
            mock_retrieve.assert_called_once()
            call_kwargs = mock_retrieve.call_args
            # Absent indexing_start must propagate through as None.
            assert call_kwargs.kwargs["start"] is None

View File

@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
class TestDeleteVoiceProvider:
"""Tests for delete_voice_provider."""
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = provider
delete_voice_provider(mock_db_session, 1)
assert provider.deleted is True
mock_db_session.delete.assert_called_once_with(provider)
mock_db_session.flush.assert_called_once()
def test_does_nothing_when_provider_not_found(

View File

@@ -3,7 +3,7 @@
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
- _validate_endpoint: httpx exception → HookValidateStatus mapping
ConnectTimeout → cannot_connect (TCP handshake never completed)
ConnectTimeout → timeout (any timeout directs user to increase timeout_seconds)
ConnectError → cannot_connect (DNS / TLS failure)
ReadTimeout et al. → timeout (TCP connected, server slow)
Any other exc → cannot_connect
@@ -61,7 +61,7 @@ class TestCheckSsrfSafety:
def test_non_https_scheme_rejected(self, url: str) -> None:
with pytest.raises(OnyxError) as exc_info:
self._call(url)
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert "https" in (exc_info.value.detail or "").lower()
# --- private IP blocklist ---
@@ -87,7 +87,7 @@ class TestCheckSsrfSafety:
):
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
self._call("https://internal.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert ip in (exc_info.value.detail or "")
def test_public_ip_is_allowed(self) -> None:
@@ -106,7 +106,7 @@ class TestCheckSsrfSafety:
pytest.raises(OnyxError) as exc_info,
):
self._call("https://no-such-host.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
# ---------------------------------------------------------------------------
@@ -158,13 +158,11 @@ class TestValidateEndpoint:
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.cannot_connect
assert self._call().status == HookValidateStatus.timeout
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(

View File

@@ -4,13 +4,23 @@ from unittest.mock import MagicMock
import pytest
from fastapi import UploadFile
from onyx.natural_language_processing import utils as nlp_utils
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.server.features.projects import projects_file_utils as utils
from onyx.server.settings.models import Settings
class _Tokenizer:
class _Tokenizer(BaseTokenizer):
def encode(self, text: str) -> list[int]:
return [1] * len(text)
def tokenize(self, text: str) -> list[str]:
return list(text)
def decode(self, _tokens: list[int]) -> str:
return ""
class _NonSeekableFile(BytesIO):
def tell(self) -> int:
@@ -29,10 +39,26 @@ def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
return UploadFile(filename=filename, file=BytesIO(content), size=None)
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
def _make_settings(upload_size_mb: int = 1, token_threshold_k: int = 100) -> Settings:
return Settings(
user_file_max_upload_size_mb=upload_size_mb,
file_token_count_threshold_k=token_threshold_k,
)
def _patch_common_dependencies(
monkeypatch: pytest.MonkeyPatch,
upload_size_mb: int = 1,
token_threshold_k: int = 100,
) -> None:
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
monkeypatch.setattr(
utils,
"load_settings",
lambda: _make_settings(upload_size_mb, token_threshold_k),
)
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
@@ -76,9 +102,8 @@ def test_is_upload_too_large_logs_warning_when_size_unknown(
def test_categorize_uploaded_files_accepts_size_under_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
# upload_size_mb=1 → max_bytes = 1*1024*1024; file size 99 is well under
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("small.png", size=99)
@@ -91,9 +116,7 @@ def test_categorize_uploaded_files_accepts_size_under_limit(
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload_no_size("small.png", content=b"x" * 99)
@@ -106,12 +129,11 @@ def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
def test_categorize_uploaded_files_accepts_size_at_limit(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("edge.png", size=100)
# 1 MB = 1048576 bytes; file at exactly that boundary should be accepted
upload = _make_upload("edge.png", size=1048576)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
@@ -121,12 +143,10 @@ def test_categorize_uploaded_files_accepts_size_at_limit(
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
upload = _make_upload("large.png", size=101)
upload = _make_upload("large.png", size=1048577) # 1 byte over 1 MB
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
@@ -137,13 +157,11 @@ def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
small = _make_upload("small.png", size=50)
large = _make_upload("large.png", size=101)
large = _make_upload("large.png", size=1048577)
result = utils.categorize_uploaded_files([small, large], MagicMock())
@@ -153,15 +171,12 @@ def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
def test_categorize_uploaded_files_enforces_size_limit_always(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
upload = _make_upload("oversized.pdf", size=101)
upload = _make_upload("oversized.pdf", size=1048577)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
@@ -172,14 +187,12 @@ def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_sk
def test_categorize_uploaded_files_checks_size_before_text_extraction(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
extract_mock = MagicMock(return_value="this should not run")
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
oversized_doc = _make_upload("oversized.pdf", size=101)
oversized_doc = _make_upload("oversized.pdf", size=1048577)
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
extract_mock.assert_not_called()
@@ -188,40 +201,219 @@ def test_categorize_uploaded_files_checks_size_before_text_extraction(
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_accepts_python_file(
def test_categorize_enforces_size_limit_when_upload_size_mb_is_positive(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
"""A positive upload_size_mb is always enforced."""
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
py_source = b'def hello():\n print("world")\n'
monkeypatch.setattr(
utils, "extract_file_text", lambda **_kwargs: py_source.decode()
)
upload = _make_upload("script.py", size=len(py_source), content=py_source)
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 1
assert result.acceptable[0].filename == "script.py"
assert len(result.rejected) == 0
def test_categorize_uploaded_files_rejects_binary_file(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_patch_common_dependencies(monkeypatch)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: "")
binary_content = bytes(range(256)) * 4
upload = _make_upload("data.bin", size=len(binary_content), content=binary_content)
upload = _make_upload("huge.png", size=1048577, content=b"x")
result = utils.categorize_uploaded_files([upload], MagicMock())
assert len(result.acceptable) == 0
assert len(result.rejected) == 1
assert result.rejected[0].filename == "data.bin"
assert "Unsupported file type" in result.rejected[0].reason
def test_categorize_enforces_token_limit_when_threshold_k_is_positive(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A positive token_threshold_k is always enforced."""
    _patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=5)
    # Pretend every image costs 6000 tokens — above the 5K threshold.
    monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
    oversized_image = _make_upload("big_image.png", size=100)
    outcome = utils.categorize_uploaded_files([oversized_image], MagicMock())
    assert not outcome.acceptable
    assert len(outcome.rejected) == 1
def test_categorize_no_token_limit_when_threshold_k_is_zero(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """token_threshold_k=0 means no token limit; high-token files are accepted."""
    _patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=0)
    # Even an absurd token estimate must not cause a rejection when the
    # threshold is disabled.
    monkeypatch.setattr(
        utils, "estimate_image_tokens_for_upload", lambda _upload: 999_999
    )
    huge_image = _make_upload("huge_image.png", size=100)
    outcome = utils.categorize_uploaded_files([huge_image], MagicMock())
    assert not outcome.rejected
    assert len(outcome.acceptable) == 1
def test_categorize_both_limits_enforced(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Both positive limits are enforced; file exceeding token limit is rejected."""
    _patch_common_dependencies(monkeypatch, upload_size_mb=10, token_threshold_k=5)
    # 6000 estimated tokens > 5K threshold, while the 100-byte size is fine.
    monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
    over_tokens = _make_upload("over_tokens.png", size=100)
    outcome = utils.categorize_uploaded_files([over_tokens], MagicMock())
    assert not outcome.acceptable
    assert len(outcome.rejected) == 1
    assert outcome.rejected[0].reason == "Exceeds 5K token limit"
def test_categorize_rejection_reason_contains_dynamic_values(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Rejection reasons reflect the admin-configured limits, not hardcoded values."""
    _patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
    # 8000 estimated tokens exceeds the 7K threshold configured above.
    monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 8000)
    # File within size limit but over token limit
    upload = _make_upload("tokens.png", size=100)
    result = utils.categorize_uploaded_files([upload], MagicMock())
    assert result.rejected[0].reason == "Exceeds 7K token limit"
    # File over size limit
    # Re-patch so the second categorize call starts from freshly-mocked
    # dependencies rather than state left over from the first call.
    _patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
    oversized = _make_upload("big.png", size=42 * 1024 * 1024 + 1)
    result2 = utils.categorize_uploaded_files([oversized], MagicMock())
    assert result2.rejected[0].reason == "Exceeds 42 MB file size limit"
# --- count_tokens tests ---
def test_count_tokens_small_text() -> None:
"""Small text should be encoded in a single call and return correct count."""
tokenizer = _Tokenizer()
text = "hello world"
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
def test_count_tokens_chunked_matches_single_call() -> None:
"""Chunked encoding should produce the same result as single-call for small text."""
tokenizer = _Tokenizer()
text = "a" * 1000
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
def test_count_tokens_large_text_is_chunked(monkeypatch: pytest.MonkeyPatch) -> None:
"""Text exceeding _ENCODE_CHUNK_SIZE should be split into multiple encode calls."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 250
# _Tokenizer returns 1 token per char, so total should be 250
assert count_tokens(text, tokenizer) == 250
def test_count_tokens_with_token_limit_exits_early(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When token_limit is set and exceeded, count_tokens should stop early."""
    monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
    encode_call_count = 0
    original_tokenizer = _Tokenizer()

    # Wrapper tokenizer that counts encode() invocations so the early exit
    # is observable from the outside.
    class _CountingTokenizer(BaseTokenizer):
        def encode(self, text: str) -> list[int]:
            nonlocal encode_call_count
            encode_call_count += 1
            return original_tokenizer.encode(text)

        def tokenize(self, text: str) -> list[str]:
            return list(text)

        def decode(self, _tokens: list[int]) -> str:
            return ""

    tokenizer = _CountingTokenizer()
    # 500 chars → 5 chunks of 100; limit=150 → should stop after 2 chunks
    text = "a" * 500
    result = count_tokens(text, tokenizer, token_limit=150)
    assert result == 200  # 2 chunks × 100 tokens each
    assert encode_call_count == 2, "Should have stopped after 2 chunks"
def test_count_tokens_with_token_limit_not_exceeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When token_limit is set but not exceeded, all chunks are encoded."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 250
result = count_tokens(text, tokenizer, token_limit=1000)
assert result == 250
def test_count_tokens_no_limit_encodes_all_chunks(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Without token_limit, all chunks are encoded regardless of count."""
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
tokenizer = _Tokenizer()
text = "a" * 500
result = count_tokens(text, tokenizer)
assert result == 500
# --- early exit via token_limit in categorize tests ---
def test_categorize_early_exits_tokenization_for_large_text(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Large text files should be rejected via early-exit tokenization
    without encoding all chunks."""
    _patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
    # token_threshold_k=1 → a 1000-token threshold. _Tokenizer yields one
    # token per character, so 5000 characters of text is well past the
    # threshold and should trigger the early exit.
    monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
    large_text = "x" * 5000  # 5000 tokens, threshold 1000
    monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: large_text)
    encode_call_count = 0
    original_tokenizer = _Tokenizer()

    # Wrapper tokenizer that counts encode() calls so we can assert that
    # tokenization stopped before processing every chunk.
    class _CountingTokenizer(BaseTokenizer):
        def encode(self, text: str) -> list[int]:
            nonlocal encode_call_count
            encode_call_count += 1
            return original_tokenizer.encode(text)

        def tokenize(self, text: str) -> list[str]:
            return list(text)

        def decode(self, _tokens: list[int]) -> str:
            return ""

    monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _CountingTokenizer())
    upload = _make_upload("big.txt", size=5000, content=large_text.encode())
    result = utils.categorize_uploaded_files([upload], MagicMock())
    assert len(result.rejected) == 1
    assert "token limit" in result.rejected[0].reason
    # 5000 chars / 100 chunk_size = 50 chunks total; should stop well before all 50
    assert (
        encode_call_count < 50
    ), f"Expected early exit but encoded {encode_call_count} chunks out of 50"
def test_categorize_text_under_token_limit_accepted(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Text files under the token threshold should be accepted with exact count."""
    _patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
    # 500 one-token characters stays below the 1K-token threshold.
    small_text = "x" * 500
    monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: small_text)
    text_upload = _make_upload("ok.txt", size=500, content=small_text.encode())
    outcome = utils.categorize_uploaded_files([text_upload], MagicMock())
    assert len(outcome.acceptable) == 1
    assert outcome.acceptable_file_to_token_count["ok.txt"] == 500

View File

@@ -1,12 +1,23 @@
import pytest
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings import store as settings_store
from onyx.server.settings.models import (
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
)
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
from onyx.server.settings.models import Settings
class _FakeKvStore:
def __init__(self, data: dict | None = None) -> None:
self._data = data
def load(self, _key: str) -> dict:
raise KvKeyNotFoundError()
if self._data is None:
raise KvKeyNotFoundError()
return self._data
class _FakeCache:
@@ -20,13 +31,140 @@ class _FakeCache:
self._vals[key] = value.encode("utf-8")
def test_load_settings_includes_user_file_max_upload_size_mb(
def test_load_settings_uses_model_defaults_when_no_stored_value(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When no settings are stored (vector DB enabled), load_settings() should
resolve the default token threshold to 200."""
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", False)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 77
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
assert (
settings.file_token_count_threshold_k
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
)
def test_load_settings_uses_high_token_default_when_vector_db_disabled(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When vector DB is disabled and no settings are stored, the token
    threshold should default to 10000 (10M tokens)."""
    monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
    monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
    monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
    loaded = settings_store.load_settings()
    assert loaded.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
    assert (
        loaded.file_token_count_threshold_k
        == DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
    )
def test_load_settings_preserves_explicit_value_when_vector_db_disabled(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When vector DB is disabled but admin explicitly set a token threshold,
    that value should be preserved (not overridden by the 10000 default)."""
    explicit = Settings(file_token_count_threshold_k=500).model_dump()
    monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(explicit))
    monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
    monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
    loaded = settings_store.load_settings()
    assert loaded.file_token_count_threshold_k == 500
def test_load_settings_preserves_zero_token_threshold(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 means 'no limit' and should be preserved."""
stored = Settings(file_token_count_threshold_k=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
settings = settings_store.load_settings()
assert settings.file_token_count_threshold_k == 0
def test_load_settings_resolves_zero_upload_size_to_default(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A value of 0 should be treated as unset and resolved to the default."""
    # NOTE(review): very similar to
    # test_load_settings_zero_upload_size_resolves_to_default below, which
    # additionally pins MAX_ALLOWED_UPLOAD_SIZE_MB — consider consolidating.
    zeroed = Settings(user_file_max_upload_size_mb=0).model_dump()
    monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(zeroed))
    monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
    loaded = settings_store.load_settings()
    assert loaded.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
def test_load_settings_clamps_upload_size_to_env_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the stored upload size exceeds MAX_ALLOWED_UPLOAD_SIZE_MB, it should
be clamped to the env-configured maximum."""
stored = Settings(user_file_max_upload_size_mb=500).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 250
def test_load_settings_preserves_upload_size_within_max(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""When the stored upload size is within MAX_ALLOWED_UPLOAD_SIZE_MB, it should
be preserved unchanged."""
stored = Settings(user_file_max_upload_size_mb=150).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 150
def test_load_settings_zero_upload_size_resolves_to_default(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A value of 0 should be treated as unset and resolved to the default,
clamped to MAX_ALLOWED_UPLOAD_SIZE_MB."""
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 100)
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
settings = settings_store.load_settings()
assert settings.user_file_max_upload_size_mb == 100
def test_load_settings_default_clamped_to_max(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB exceeds MAX_ALLOWED_UPLOAD_SIZE_MB,
    the effective default should be min(DEFAULT, MAX)."""
    monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
    monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
    # Default (100) above the env cap (50) → the cap wins.
    monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
    monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 50)
    loaded = settings_store.load_settings()
    assert loaded.user_file_max_upload_size_mb == 50

View File

@@ -39,6 +39,22 @@ server {
# Conditionally include MCP location configuration
include /etc/nginx/conf.d/mcp.conf.inc;
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
proxy_pass http://api_server;
}
# Match both /api/* and /openapi.json in a single rule
location ~ ^/(api|openapi.json)(/.*)?$ {
# Rewrite /api prefixed matched paths

View File

@@ -39,6 +39,20 @@ server {
# Conditionally include MCP location configuration
include /etc/nginx/conf.d/mcp.conf.inc;
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
proxy_pass http://api_server;
}
# Match both /api/* and /openapi.json in a single rule
location ~ ^/(api|openapi.json)(/.*)?$ {
# Rewrite /api prefixed matched paths

View File

@@ -39,6 +39,23 @@ server {
# Conditionally include MCP location configuration
include /etc/nginx/conf.d/mcp.conf.inc;
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header X-Forwarded-Port $server_port;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
proxy_pass http://api_server;
}
# Match both /api/* and /openapi.json in a single rule
location ~ ^/(api|openapi.json)(/.*)?$ {
# Rewrite /api prefixed matched paths

View File

@@ -66,10 +66,3 @@ DB_READONLY_PASSWORD=password
# Show extra/uncommon connectors
# See https://docs.onyx.app/admins/connectors/overview for a full list of connectors
SHOW_EXTRA_CONNECTORS=False
# User File Upload Configuration
# Skip the token count threshold check (100,000 tokens) for uploaded files
# For self-hosted: set to true to skip for all users
#SKIP_USERFILE_THRESHOLD=false
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
#SKIP_USERFILE_THRESHOLD_TENANT_IDS=

View File

@@ -35,6 +35,10 @@ USER_AUTH_SECRET=""
## Chat Configuration
# HARD_DELETE_CHATS=
# MAX_ALLOWED_UPLOAD_SIZE_MB=250
# Default per-user upload size limit (MB) when no admin value is set.
# Automatically clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at runtime.
# DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=100
## Base URL for redirects
# WEB_DOMAIN=
@@ -42,13 +46,6 @@ USER_AUTH_SECRET=""
## Enterprise Features, requires a paid plan and licenses
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=false
## User File Upload Configuration
# Skip the token count threshold check (100,000 tokens) for uploaded files
# For self-hosted: set to true to skip for all users
# SKIP_USERFILE_THRESHOLD=false
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
# SKIP_USERFILE_THRESHOLD_TENANT_IDS=
################################################################################
## SERVICES CONFIGURATIONS

View File

@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.4.36
version: 0.4.38
appVersion: latest
annotations:
category: Productivity

View File

@@ -0,0 +1,26 @@
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_docfetching.replicaCount) 0) }}
apiVersion: v1
kind: Service
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docfetching-metrics
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- if .Values.celery_worker_docfetching.deploymentLabels }}
{{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 4 }}
{{- end }}
metrics: "true"
spec:
type: ClusterIP
ports:
- port: 9092
targetPort: metrics
protocol: TCP
name: metrics
selector:
{{- include "onyx.selectorLabels" . | nindent 4 }}
{{- if .Values.celery_worker_docfetching.deploymentLabels }}
{{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 4 }}
{{- end }}
{{- end }}

View File

@@ -73,6 +73,10 @@ spec:
"-Q",
"connector_doc_fetching",
]
ports:
- name: metrics
containerPort: 9092
protocol: TCP
resources:
{{- toYaml .Values.celery_worker_docfetching.resources | nindent 12 }}
envFrom:

View File

@@ -0,0 +1,26 @@
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_docprocessing.replicaCount) 0) }}
apiVersion: v1
kind: Service
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docprocessing-metrics
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- if .Values.celery_worker_docprocessing.deploymentLabels }}
{{- toYaml .Values.celery_worker_docprocessing.deploymentLabels | nindent 4 }}
{{- end }}
metrics: "true"
spec:
type: ClusterIP
ports:
- port: 9093
targetPort: metrics
protocol: TCP
name: metrics
selector:
{{- include "onyx.selectorLabels" . | nindent 4 }}
{{- if .Values.celery_worker_docprocessing.deploymentLabels }}
{{- toYaml .Values.celery_worker_docprocessing.deploymentLabels | nindent 4 }}
{{- end }}
{{- end }}

View File

@@ -73,6 +73,10 @@ spec:
"-Q",
"docprocessing",
]
ports:
- name: metrics
containerPort: 9093
protocol: TCP
resources:
{{- toYaml .Values.celery_worker_docprocessing.resources | nindent 12 }}
envFrom:

View File

@@ -0,0 +1,26 @@
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_monitoring.replicaCount) 0) }}
apiVersion: v1
kind: Service
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-monitoring-metrics
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- if .Values.celery_worker_monitoring.deploymentLabels }}
{{- toYaml .Values.celery_worker_monitoring.deploymentLabels | nindent 4 }}
{{- end }}
metrics: "true"
spec:
type: ClusterIP
ports:
- port: 9096
targetPort: metrics
protocol: TCP
name: metrics
selector:
{{- include "onyx.selectorLabels" . | nindent 4 }}
{{- if .Values.celery_worker_monitoring.deploymentLabels }}
{{- toYaml .Values.celery_worker_monitoring.deploymentLabels | nindent 4 }}
{{- end }}
{{- end }}

View File

@@ -70,6 +70,10 @@ spec:
"-Q",
"monitoring",
]
ports:
- name: metrics
containerPort: 9096
protocol: TCP
resources:
{{- toYaml .Values.celery_worker_monitoring.resources | nindent 12 }}
envFrom:

View File

@@ -63,6 +63,22 @@ data:
}
{{- end }}
location ~ ^/scim(/.*)?$ {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $host;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_buffering off;
proxy_redirect off;
# timeout settings
proxy_connect_timeout {{ .Values.nginx.timeouts.connect }}s;
proxy_send_timeout {{ .Values.nginx.timeouts.send }}s;
proxy_read_timeout {{ .Values.nginx.timeouts.read }}s;
proxy_pass http://api_server;
}
location ~ ^/(api|openapi\.json)(/.*)?$ {
rewrite ^/api(/.*)$ $1 break;
proxy_set_header X-Real-IP $remote_addr;

View File

@@ -282,7 +282,7 @@ nginx:
# The ingress-nginx subchart doesn't auto-detect our custom ConfigMap changes.
# Workaround: Helm upgrade will restart if the following annotation value changes.
podAnnotations:
onyx.app/nginx-config-version: "2"
onyx.app/nginx-config-version: "3"
# Propagate DOMAIN into nginx so server_name continues to use the same env var
extraEnvs:
@@ -1285,11 +1285,5 @@ configMap:
DOMAIN: "localhost"
# Chat Configs
HARD_DELETE_CHATS: ""
# User File Upload Configuration
# Skip the token count threshold check (100,000 tokens) for uploaded files
# For self-hosted: set to true to skip for all users
SKIP_USERFILE_THRESHOLD: ""
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
SKIP_USERFILE_THRESHOLD_TENANT_IDS: ""
# Maximum user upload file size in MB for chat/projects uploads
USER_FILE_MAX_UPLOAD_SIZE_MB: ""
MAX_ALLOWED_UPLOAD_SIZE_MB: ""
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB: ""

View File

@@ -28,7 +28,7 @@ Some commands require external tools to be installed and configured:
- **uv** - Required for `backend` commands
- Install from [docs.astral.sh/uv](https://docs.astral.sh/uv/)
- **GitHub CLI** (`gh`) - Required for `run-ci` and `cherry-pick` commands
- **GitHub CLI** (`gh`) - Required for `run-ci`, `cherry-pick`, and `trace` commands
- Install from [cli.github.com](https://cli.github.com/)
- Authenticate with `gh auth login`
@@ -412,6 +412,62 @@ The `compare` subcommand writes a `summary.json` alongside the report with aggre
counts (changed, added, removed, unchanged). The HTML report is only generated when
visual differences are detected.
### `trace` - View Playwright Traces from CI
Download Playwright trace artifacts from a GitHub Actions run and open them
with `playwright show-trace`. Traces are only generated for failing tests
(`retain-on-failure`).
```shell
ods trace [run-id-or-url]
```
The run can be specified as a numeric run ID, a full GitHub Actions URL, or
omitted to find the latest Playwright run for the current branch.
**Flags:**
| Flag | Default | Description |
|------|---------|-------------|
| `--branch`, `-b` | | Find latest run for this branch |
| `--pr` | | Find latest run for this PR number |
| `--project`, `-p` | | Filter to a specific project (`admin`, `exclusive`, `lite`) |
| `--list`, `-l` | `false` | List available traces without opening |
| `--no-open` | `false` | Download traces but don't open them |
When multiple traces are found, an interactive picker lets you select which
traces to open. Use arrow keys or `j`/`k` to navigate, `space` to toggle,
`a` to select all, `n` to deselect all, and `enter` to open. Falls back to a
plain-text prompt when no TTY is available.
Downloaded artifacts are cached in `/tmp/ods-traces/<run-id>/` so repeated
invocations for the same run are instant.
**Examples:**
```shell
# Latest run for the current branch
ods trace
# Specific run ID
ods trace 12345678
# Full GitHub Actions URL
ods trace https://github.com/onyx-dot-app/onyx/actions/runs/12345678
# Latest run for a PR
ods trace --pr 9500
# Latest run for a specific branch
ods trace --branch main
# Only download admin project traces
ods trace --project admin
# List traces without opening
ods trace --list
```
### Testing Changes Locally (Dry Run)
Both `run-ci` and `cherry-pick` support `--dry-run` to test without making remote changes:

View File

@@ -55,6 +55,7 @@ func NewRootCommand() *cobra.Command {
cmd.AddCommand(NewWebCommand())
cmd.AddCommand(NewLatestStableTagCommand())
cmd.AddCommand(NewWhoisCommand())
cmd.AddCommand(NewTraceCommand())
return cmd
}

556
tools/ods/cmd/trace.go Normal file
View File

@@ -0,0 +1,556 @@
package cmd
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/onyx-dot-app/onyx/tools/ods/internal/git"
"github.com/onyx-dot-app/onyx/tools/ods/internal/paths"
"github.com/onyx-dot-app/onyx/tools/ods/internal/tui"
)
const playwrightWorkflow = "Run Playwright Tests"
// TraceOptions holds options for the trace command
type TraceOptions struct {
	Branch  string // --branch: find the latest run for this branch
	PR      string // --pr: find the latest run for this PR number
	Project string // --project: restrict downloads to one Playwright project
	List    bool   // --list: print available traces without opening them
	NoOpen  bool   // --no-open: download traces but skip opening them
}
// traceInfo describes a single trace.zip found in the downloaded artifacts.
type traceInfo struct {
Path string // absolute path to trace.zip
Project string // project group extracted from artifact dir (e.g. "admin", "admin-shard-1")
TestDir string // test directory name (human-readable-ish)
}
// NewTraceCommand creates a new trace command
func NewTraceCommand() *cobra.Command {
opts := &TraceOptions{}
cmd := &cobra.Command{
Use: "trace [run-id-or-url]",
Short: "Download and view Playwright traces from GitHub Actions",
Long: `Download Playwright trace artifacts from a GitHub Actions run and open them
with 'playwright show-trace'.
The run can be specified as:
- A GitHub Actions run ID (numeric)
- A full GitHub Actions run URL
- Omitted, to find the latest Playwright run for the current branch
You can also look up the latest run by branch name or PR number.
Examples:
ods trace # latest run for current branch
ods trace 12345678 # specific run ID
ods trace https://github.com/onyx-dot-app/onyx/actions/runs/12345678
ods trace --pr 9500 # latest run for PR #9500
ods trace --branch main # latest run for main branch
ods trace --project admin # only download admin project traces
ods trace --list # list available traces without opening`,
Args: cobra.MaximumNArgs(1),
Run: func(cmd *cobra.Command, args []string) {
runTrace(args, opts)
},
}
cmd.Flags().StringVarP(&opts.Branch, "branch", "b", "", "Find latest run for this branch")
cmd.Flags().StringVar(&opts.PR, "pr", "", "Find latest run for this PR number")
cmd.Flags().StringVarP(&opts.Project, "project", "p", "", "Filter to a specific project (admin, exclusive, lite)")
cmd.Flags().BoolVarP(&opts.List, "list", "l", false, "List available traces without opening")
cmd.Flags().BoolVar(&opts.NoOpen, "no-open", false, "Download traces but don't open them")
return cmd
}
// ghRun represents a GitHub Actions workflow run from `gh run list`
type ghRun struct {
	DatabaseID int64  `json:"databaseId"` // numeric run ID used by other gh commands
	Status     string `json:"status"`
	Conclusion string `json:"conclusion"`
	HeadBranch string `json:"headBranch"`
	URL        string `json:"url"`
}
// runTrace is the entry point for the trace command: it resolves which run to
// inspect, downloads that run's Playwright trace artifacts, and then lists or
// opens the discovered traces according to opts. Failures are fatal
// (log.Fatalf) because this executes as a top-level CLI action.
func runTrace(args []string, opts *TraceOptions) {
	git.CheckGitHubCLI()
	runID, err := resolveRunID(args, opts)
	if err != nil {
		log.Fatalf("Failed to resolve run: %v", err)
	}
	log.Infof("Using run ID: %s", runID)
	destDir, err := downloadTraceArtifacts(runID, opts.Project)
	if err != nil {
		log.Fatalf("Failed to download artifacts: %v", err)
	}
	traces, err := findTraceInfos(destDir, runID)
	if err != nil {
		log.Fatalf("Failed to find traces: %v", err)
	}
	if len(traces) == 0 {
		// Not an error: Playwright only keeps traces for failing tests.
		log.Info("No trace files found in the downloaded artifacts.")
		log.Info("Traces are only generated for failing tests (retain-on-failure).")
		return
	}
	projects := groupByProject(traces)
	if opts.List || opts.NoOpen {
		printTraceList(traces, projects)
		fmt.Printf("\nTraces downloaded to: %s\n", destDir)
		return
	}
	// Single trace: skip the picker entirely.
	if len(traces) == 1 {
		openTraces(traces)
		return
	}
	// Loop so the user can open traces repeatedly until they cancel the picker.
	for {
		selected := selectTraces(traces, projects)
		if len(selected) == 0 {
			return
		}
		openTraces(selected)
	}
}
// resolveRunID determines the run ID from the provided arguments and options.
// Precedence: explicit positional arg, then --pr, then --branch, then the
// currently checked-out branch.
func resolveRunID(args []string, opts *TraceOptions) (string, error) {
	switch {
	case len(args) == 1:
		return parseRunIDFromArg(args[0])
	case opts.PR != "":
		return findLatestRunForPR(opts.PR)
	}

	branch := opts.Branch
	if branch == "" {
		current, err := git.GetCurrentBranch()
		if err != nil {
			return "", fmt.Errorf("failed to get current branch: %w", err)
		}
		if current == "" {
			return "", fmt.Errorf("detached HEAD; specify a --branch, --pr, or run ID")
		}
		log.Infof("Using current branch: %s", current)
		branch = current
	}
	return findLatestRunForBranch(branch)
}
var (
	// runURLPattern extracts the numeric run ID from a GitHub Actions run URL.
	runURLPattern = regexp.MustCompile(`/actions/runs/(\d+)`)
	// runIDPattern matches a bare, all-digits run ID. Compiled once instead of
	// re-compiling (and silently discarding the error) on every call.
	runIDPattern = regexp.MustCompile(`^\d+$`)
)

// parseRunIDFromArg extracts a run ID from either a numeric string or a full URL.
// Returns an error if the argument is neither.
func parseRunIDFromArg(arg string) (string, error) {
	if runIDPattern.MatchString(arg) {
		return arg, nil
	}
	if matches := runURLPattern.FindStringSubmatch(arg); matches != nil {
		return matches[1], nil
	}
	return "", fmt.Errorf("could not parse run ID from %q; expected a numeric ID or GitHub Actions URL", arg)
}
// findLatestRunForBranch finds the most recent Playwright workflow run for a
// branch by shelling out to `gh run list` with --limit 1, and returns its
// database ID as a decimal string. Errors include gh's stderr via ghError.
func findLatestRunForBranch(branch string) (string, error) {
	log.Infof("Looking up latest Playwright run for branch: %s", branch)
	cmd := exec.Command("gh", "run", "list",
		"--workflow", playwrightWorkflow,
		"--branch", branch,
		"--limit", "1",
		"--json", "databaseId,status,conclusion,headBranch,url",
	)
	output, err := cmd.Output()
	if err != nil {
		return "", ghError(err, "gh run list failed")
	}
	var runs []ghRun
	if err := json.Unmarshal(output, &runs); err != nil {
		return "", fmt.Errorf("failed to parse run list: %w", err)
	}
	// gh returns an empty JSON array when the branch has no matching runs.
	if len(runs) == 0 {
		return "", fmt.Errorf("no Playwright runs found for branch %q", branch)
	}
	run := runs[0]
	log.Infof("Found run: %s (status: %s, conclusion: %s)", run.URL, run.Status, run.Conclusion)
	return fmt.Sprintf("%d", run.DatabaseID), nil
}
// findLatestRunForPR finds the most recent Playwright workflow run for a PR.
// It resolves the PR's head branch via `gh pr view` and then delegates to
// findLatestRunForBranch.
func findLatestRunForPR(prNumber string) (string, error) {
	log.Infof("Looking up branch for PR #%s", prNumber)
	out, err := exec.Command("gh", "pr", "view", prNumber,
		"--json", "headRefName",
		"--jq", ".headRefName",
	).Output()
	if err != nil {
		return "", ghError(err, "gh pr view failed")
	}

	branch := strings.TrimSpace(string(out))
	if branch == "" {
		return "", fmt.Errorf("could not determine branch for PR #%s", prNumber)
	}

	log.Infof("PR #%s is on branch: %s", prNumber, branch)
	return findLatestRunForBranch(branch)
}
// downloadTraceArtifacts downloads playwright trace artifacts for a run.
// Returns the path to the download directory. Downloads are cached under the
// OS temp dir, keyed by run ID (plus project when filtering), and reused when
// they already contain traces.
func downloadTraceArtifacts(runID string, project string) (string, error) {
	cacheKey := runID
	if project != "" {
		cacheKey = runID + "-" + project
	}
	destDir := filepath.Join(os.TempDir(), "ods-traces", cacheKey)

	// Reuse a previous download if traces exist; otherwise discard the stale dir.
	if info, err := os.Stat(destDir); err == nil && info.IsDir() {
		if cached, _ := findTraces(destDir); len(cached) > 0 {
			log.Infof("Using cached download at %s", destDir)
			return destDir, nil
		}
		_ = os.RemoveAll(destDir)
	}
	if err := os.MkdirAll(destDir, 0755); err != nil {
		return "", fmt.Errorf("failed to create directory %s: %w", destDir, err)
	}

	// Restrict the download to trace artifacts, optionally narrowed to one project.
	pattern := "playwright-test-results-*"
	if project != "" {
		pattern = fmt.Sprintf("playwright-test-results-%s-*", project)
	}
	ghArgs := []string{"run", "download", runID, "--dir", destDir, "--pattern", pattern}

	log.Infof("Downloading trace artifacts...")
	log.Debugf("Running: gh %s", strings.Join(ghArgs, " "))
	cmd := exec.Command("gh", ghArgs...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		// Don't leave a half-populated cache directory behind.
		_ = os.RemoveAll(destDir)
		return "", fmt.Errorf("gh run download failed: %w\nMake sure the run ID is correct and the artifacts haven't expired (30 day retention)", err)
	}
	return destDir, nil
}
// findTraces recursively finds all trace.zip files under a directory.
func findTraces(root string) ([]string, error) {
var traces []string
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && info.Name() == "trace.zip" {
traces = append(traces, path)
}
return nil
})
return traces, err
}
// findTraceInfos walks the download directory and returns structured trace info.
// Expects: destDir/{artifact-dir}/{test-dir}/trace.zip
//
// Fix: previously, a trace.zip sitting directly under the artifact dir
// (rel = "{artifact-dir}/trace.zip", only 2 path parts) set TestDir to the
// literal "trace.zip"; now the parts[1] override only applies when a real
// test directory component exists, otherwise the parent-dir fallback is kept.
func findTraceInfos(destDir, runID string) ([]traceInfo, error) {
	var traces []traceInfo
	err := filepath.WalkDir(destDir, func(path string, d os.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if d.IsDir() || d.Name() != "trace.zip" {
			return nil
		}
		rel, _ := filepath.Rel(destDir, path)
		parts := strings.SplitN(rel, string(filepath.Separator), 3)
		artifactDir := ""
		// Fallback: name of the directory immediately containing trace.zip.
		testDir := filepath.Base(filepath.Dir(path))
		if len(parts) >= 2 {
			artifactDir = parts[0]
		}
		if len(parts) >= 3 {
			// rel is {artifact-dir}/{test-dir}/...; use the test-dir component.
			testDir = parts[1]
		}
		traces = append(traces, traceInfo{
			Path:    path,
			Project: extractProject(artifactDir, runID),
			TestDir: testDir,
		})
		return nil
	})
	// Stable display order: project groups first (admin < exclusive < lite),
	// then test directory name within a group.
	sort.Slice(traces, func(i, j int) bool {
		pi, pj := projectSortKey(traces[i].Project), projectSortKey(traces[j].Project)
		if pi != pj {
			return pi < pj
		}
		return traces[i].TestDir < traces[j].TestDir
	})
	return traces, err
}
// extractProject derives a project group from an artifact directory name.
// e.g. "playwright-test-results-admin-12345" -> "admin"
//
// "playwright-test-results-admin-shard-1-12345" -> "admin-shard-1"
//
// If stripping the known prefix and the "-<runID>" suffix leaves nothing,
// the original directory name is returned unchanged.
func extractProject(artifactDir, runID string) string {
	trimmed := strings.TrimSuffix(
		strings.TrimPrefix(artifactDir, "playwright-test-results-"),
		"-"+runID,
	)
	if trimmed == "" {
		return artifactDir
	}
	return trimmed
}
// projectSortKey returns a sort-friendly key that orders admin < exclusive < lite.
// Unrecognized projects sort last. The key is "<rank>-<project>" so ties within
// a rank fall back to lexicographic project order.
func projectSortKey(project string) string {
	ranked := []string{"admin", "exclusive", "lite"}
	for rank, prefix := range ranked {
		if strings.HasPrefix(project, prefix) {
			return fmt.Sprintf("%d-%s", rank, project)
		}
	}
	return "3-" + project
}
// groupByProject returns an ordered list of unique project names found in
// traces, sorted by projectSortKey (admin < exclusive < lite < other).
func groupByProject(traces []traceInfo) []string {
	seen := make(map[string]bool, len(traces))
	var projects []string
	for _, t := range traces {
		if seen[t.Project] {
			continue
		}
		seen[t.Project] = true
		projects = append(projects, t.Project)
	}
	sort.Slice(projects, func(i, j int) bool {
		return projectSortKey(projects[i]) < projectSortKey(projects[j])
	})
	return projects
}
// printTraceList displays traces grouped by project. The printed [n] indices
// are 1-based and run across all projects in display order, matching the
// numbering accepted by promptTraceSelection.
func printTraceList(traces []traceInfo, projects []string) {
	fmt.Printf("\nFound %d trace(s) across %d project(s):\n", len(traces), len(projects))
	// Bucket traces once instead of rescanning the whole slice per project
	// (previously O(projects x traces)). Order within each bucket preserves
	// the order of the input slice, so the printed numbering is unchanged.
	byProject := make(map[string][]traceInfo, len(projects))
	for _, t := range traces {
		byProject[t.Project] = append(byProject[t.Project], t)
	}
	idx := 1
	for _, proj := range projects {
		group := byProject[proj]
		fmt.Printf("\n %s (%d):\n", proj, len(group))
		for _, t := range group {
			fmt.Printf(" [%2d] %s\n", idx, t.TestDir)
			idx++
		}
	}
}
// selectTraces tries the TUI picker first, falling back to a plain-text
// prompt when the terminal cannot be initialised (e.g. piped output).
// Returns nil when the user cancels. The flat indices returned by the picker
// map directly onto the traces slice because the picker groups are built in
// the same (already sorted) order as traces.
func selectTraces(traces []traceInfo, projects []string) []traceInfo {
	// Build picker groups in the same order as the sorted traces slice.
	// This ordering invariant is what makes the flat indices valid below.
	var groups []tui.PickerGroup
	for _, proj := range projects {
		var items []string
		for _, t := range traces {
			if t.Project == proj {
				items = append(items, t.TestDir)
			}
		}
		groups = append(groups, tui.PickerGroup{Label: proj, Items: items})
	}
	indices, err := tui.Pick(groups)
	if err != nil {
		// Terminal not available — fall back to text prompt
		log.Debugf("TUI picker unavailable: %v", err)
		printTraceList(traces, projects)
		return promptTraceSelection(traces, projects)
	}
	if indices == nil {
		return nil // user cancelled
	}
	selected := make([]traceInfo, len(indices))
	for i, idx := range indices {
		selected[i] = traces[idx]
	}
	return selected
}
// promptTraceSelection asks the user which traces to open via plain text.
// Accepts numbers (1,3,5), ranges (1-5), "all", or a project name.
// An empty or unparsable answer opens all traces.
func promptTraceSelection(traces []traceInfo, projects []string) []traceInfo {
	fmt.Printf("\nOpen which traces? (e.g. 1,3,5 | 1-5 | all | %s): ", strings.Join(projects, " | "))
	line, err := bufio.NewReader(os.Stdin).ReadString('\n')
	if err != nil {
		log.Fatalf("Failed to read input: %v", err)
	}
	input := strings.TrimSpace(line)

	// Empty answer or "all" selects everything.
	if input == "" || strings.EqualFold(input, "all") {
		return traces
	}

	// A project name selects every trace in that project.
	for _, proj := range projects {
		if !strings.EqualFold(input, proj) {
			continue
		}
		var selected []traceInfo
		for _, t := range traces {
			if t.Project == proj {
				selected = append(selected, t)
			}
		}
		return selected
	}

	// Otherwise treat the answer as a number/range selection.
	indices := parseTraceSelection(input, len(traces))
	if len(indices) == 0 {
		log.Warn("No valid selection; opening all traces")
		return traces
	}
	selected := make([]traceInfo, 0, len(indices))
	for _, idx := range indices {
		selected = append(selected, traces[idx])
	}
	return selected
}
// parseTraceSelection parses a comma-separated list of numbers and ranges into
// 0-based indices. Input is 1-indexed (matches display). Out-of-range values
// are silently ignored, and duplicates are returned only once, in first-seen
// order.
func parseTraceSelection(input string, max int) []int {
	var result []int
	seen := make(map[int]bool)

	// add records a 1-based selection if it is in range and not yet seen.
	add := func(oneBased int) {
		zi := oneBased - 1
		if zi >= 0 && zi < max && !seen[zi] {
			seen[zi] = true
			result = append(result, zi)
		}
	}

	for _, raw := range strings.Split(input, ",") {
		part := strings.TrimSpace(raw)
		if part == "" {
			continue
		}
		// A dash after the first character means a lo-hi range; a leading dash
		// is treated as a (rejected) negative number, as before.
		if loStr, hiStr, found := strings.Cut(part, "-"); found && loStr != "" {
			lo, errLo := strconv.Atoi(strings.TrimSpace(loStr))
			hi, errHi := strconv.Atoi(strings.TrimSpace(hiStr))
			if errLo != nil || errHi != nil {
				continue
			}
			for i := lo; i <= hi; i++ {
				add(i)
			}
			continue
		}
		if n, err := strconv.Atoi(part); err == nil {
			add(n)
		}
	}
	return result
}
// openTraces opens the selected traces with playwright show-trace,
// running npx from the web/ directory to use the project's Playwright version.
// A non-zero exit from playwright (e.g. the user closing the viewer) is
// treated as normal so the caller's picker loop can continue; only failures
// to launch the process at all are reported as errors.
func openTraces(traces []traceInfo) {
	tracePaths := make([]string, len(traces))
	for i, t := range traces {
		tracePaths[i] = t.Path
	}
	args := append([]string{"playwright", "show-trace"}, tracePaths...)
	log.Infof("Opening %d trace(s) with playwright show-trace...", len(traces))
	cmd := exec.Command("npx", args...)
	// Run from web/ to pick up the locally-installed Playwright version
	if root, err := paths.GitRoot(); err == nil {
		cmd.Dir = filepath.Join(root, "web")
	}
	// Wire the viewer to the user's terminal so it is fully interactive.
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Stdin = os.Stdin
	if err := cmd.Run(); err != nil {
		var exitErr *exec.ExitError
		if errors.As(err, &exitErr) {
			// Normal exit (e.g. user closed the window) — just log and return
			// so the picker loop can continue.
			log.Debugf("playwright exited with code %d", exitErr.ExitCode())
			return
		}
		log.Errorf("playwright show-trace failed: %v\nMake sure Playwright is installed (npx playwright install)", err)
	}
}
// ghError wraps a gh CLI error with stderr output.
func ghError(err error, msg string) error {
if exitErr, ok := err.(*exec.ExitError); ok {
return fmt.Errorf("%s: %w: %s", msg, err, string(exitErr.Stderr))
}
return fmt.Errorf("%s: %w", msg, err)
}

View File

@@ -3,13 +3,19 @@ module github.com/onyx-dot-app/onyx/tools/ods
go 1.26.0
require (
github.com/gdamore/tcell/v2 v2.13.8
github.com/jmelahman/tag v0.5.2
github.com/sirupsen/logrus v1.9.3
github.com/sirupsen/logrus v1.9.4
github.com/spf13/cobra v1.10.2
github.com/spf13/pflag v1.0.10
)
require (
github.com/gdamore/encoding v1.0.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
golang.org/x/sys v0.39.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/term v0.41.0 // indirect
golang.org/x/text v0.35.0 // indirect
)

View File

@@ -1,30 +1,68 @@
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw=
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
github.com/gdamore/tcell/v2 v2.13.8 h1:Mys/Kl5wfC/GcC5Cx4C2BIQH9dbnhnkPgS9/wF3RlfU=
github.com/gdamore/tcell/v2 v2.13.8/go.mod h1:+Wfe208WDdB7INEtCsNrAN6O2m+wsTPk1RAovjaILlo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jmelahman/tag v0.5.2 h1:g6A/aHehu5tkA31mPoDsXBNr1FigZ9A82Y8WVgb/WsM=
github.com/jmelahman/tag v0.5.2/go.mod h1:qmuqk19B1BKkpcg3kn7l/Eey+UqucLxgOWkteUGiG4Q=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,419 @@
package tui
import (
"fmt"
"github.com/gdamore/tcell/v2"
)
// PickerGroup represents a labelled group of selectable items.
// Groups are rendered with a bold header row followed by their items.
type PickerGroup struct {
	Label string   // header text shown above the group
	Items []string // one selectable row per item, in display order
}
// entry is a single row in the picker (either a group header or an item).
// Headers are not selectable themselves; toggling one toggles its whole group.
type entry struct {
	label    string // display text (group label or item text)
	isHeader bool   // true for group header rows
	selected bool   // only meaningful for item rows
	groupIdx int    // index into the original groups slice
	flatIdx  int    // index across all items (ignoring headers), -1 for headers
}
// Pick shows a full-screen grouped multi-select picker.
// All items start deselected. Returns the flat indices of selected items
// (0-based, spanning all groups in order). Returns nil if cancelled.
// Returns a non-nil error if the terminal cannot be initialised, in which
// case the caller should fall back to a simpler prompt.
func Pick(groups []PickerGroup) ([]int, error) {
	screen, err := tcell.NewScreen()
	if err != nil {
		return nil, err
	}
	if err := screen.Init(); err != nil {
		return nil, err
	}
	// Restore the terminal no matter how we leave the event loop.
	defer screen.Fini()
	entries := buildEntries(groups)
	totalItems := countItems(entries)
	// Start on the first item row, skipping the leading group header.
	cursor := firstSelectableIndex(entries)
	offset := 0 // first visible entry row (scroll position)
	// Event loop: redraw, block on the next event, apply it, then re-clamp
	// the scroll offset so the cursor stays visible.
	for {
		w, h := screen.Size()
		selectedCount := countSelected(entries)
		drawPicker(screen, entries, groups, cursor, offset, w, h, selectedCount, totalItems)
		screen.Show()
		ev := screen.PollEvent()
		switch ev := ev.(type) {
		case *tcell.EventResize:
			screen.Sync()
		case *tcell.EventKey:
			switch action := keyAction(ev); action {
			case actionQuit:
				// Cancelled: nil indices, no error.
				return nil, nil
			case actionConfirm:
				// Confirm only when something is selected; otherwise ignore.
				if countSelected(entries) > 0 {
					return collectSelected(entries), nil
				}
			case actionUp:
				if cursor > 0 {
					cursor--
				}
			case actionDown:
				if cursor < len(entries)-1 {
					cursor++
				}
			case actionTop:
				cursor = 0
			case actionBottom:
				if len(entries) == 0 {
					cursor = 0
				} else {
					cursor = len(entries) - 1
				}
			case actionPageUp:
				// Page size is the visible list area (screen minus chrome).
				listHeight := h - headerLines - footerLines
				cursor -= listHeight
				if cursor < 0 {
					cursor = 0
				}
			case actionPageDown:
				listHeight := h - headerLines - footerLines
				cursor += listHeight
				if cursor >= len(entries) {
					cursor = len(entries) - 1
				}
			case actionToggle:
				toggleAtCursor(entries, cursor)
			case actionAll:
				setAll(entries, true)
			case actionNone:
				setAll(entries, false)
			}
			// Keep the cursor visible: scroll up or down just enough.
			listHeight := h - headerLines - footerLines
			if listHeight < 1 {
				listHeight = 1
			}
			if cursor < offset {
				offset = cursor
			}
			if cursor >= offset+listHeight {
				offset = cursor - listHeight + 1
			}
		}
	}
}
// --- actions ----------------------------------------------------------------

// action is the abstract command produced by a key press; keyAction maps
// raw tcell key events onto these so the Pick event loop stays key-agnostic.
type action int

const (
	actionNoop     action = iota // unrecognised key; do nothing
	actionQuit                   // cancel the picker (q / Esc / Ctrl-C)
	actionConfirm                // accept current selection (Enter)
	actionUp                     // move cursor up one row
	actionDown                   // move cursor down one row
	actionTop                    // jump to first row
	actionBottom                 // jump to last row
	actionPageUp                 // move up one screenful
	actionPageDown               // move down one screenful
	actionToggle                 // toggle item or whole group (Space)
	actionAll                    // select every item
	actionNone                   // deselect every item
)
// keyAction translates a tcell key event into a picker action. Unmapped
// keys yield actionNoop. Both special keys (arrows, Enter, …) and vi-style
// rune bindings (j/k/g/G, space, a, n, q) are supported.
func keyAction(ev *tcell.EventKey) action {
	specialKeys := map[tcell.Key]action{
		tcell.KeyEscape: actionQuit,
		tcell.KeyCtrlC:  actionQuit,
		tcell.KeyEnter:  actionConfirm,
		tcell.KeyUp:     actionUp,
		tcell.KeyDown:   actionDown,
		tcell.KeyHome:   actionTop,
		tcell.KeyEnd:    actionBottom,
		tcell.KeyPgUp:   actionPageUp,
		tcell.KeyPgDn:   actionPageDown,
	}
	if a, ok := specialKeys[ev.Key()]; ok {
		return a
	}
	if ev.Key() == tcell.KeyRune {
		runeKeys := map[rune]action{
			'q': actionQuit,
			' ': actionToggle,
			'j': actionDown,
			'k': actionUp,
			'g': actionTop,
			'G': actionBottom,
			'a': actionAll,
			'n': actionNone,
		}
		if a, ok := runeKeys[ev.Rune()]; ok {
			return a
		}
	}
	return actionNoop
}
// --- data helpers ------------------------------------------------------------

// buildEntries flattens the picker groups into a single row list: one header
// entry per group followed by its items. Items are numbered with a running
// flat index across all groups; headers get flatIdx -1.
func buildEntries(groups []PickerGroup) []entry {
	var entries []entry
	nextFlat := 0
	for gi, g := range groups {
		header := entry{label: g.Label, isHeader: true, groupIdx: gi, flatIdx: -1}
		entries = append(entries, header)
		for _, item := range g.Items {
			entries = append(entries, entry{
				label:    item,
				groupIdx: gi,
				flatIdx:  nextFlat,
			})
			nextFlat++
		}
	}
	return entries
}
// firstSelectableIndex returns the index of the first non-header row,
// or 0 when there is none (empty picker or headers only).
func firstSelectableIndex(entries []entry) int {
	for i := range entries {
		if !entries[i].isHeader {
			return i
		}
	}
	return 0
}
// countItems returns the number of item rows (headers excluded).
func countItems(entries []entry) int {
	total := 0
	for i := range entries {
		if entries[i].isHeader {
			continue
		}
		total++
	}
	return total
}
// countSelected returns how many item rows are currently selected.
func countSelected(entries []entry) int {
	total := 0
	for i := range entries {
		if entries[i].isHeader || !entries[i].selected {
			continue
		}
		total++
	}
	return total
}
// collectSelected returns the flat indices of all selected items, in row order.
func collectSelected(entries []entry) []int {
	var picked []int
	for i := range entries {
		if entries[i].isHeader || !entries[i].selected {
			continue
		}
		picked = append(picked, entries[i].flatIdx)
	}
	return picked
}
// toggleAtCursor toggles the row under the cursor. On an item it flips that
// item; on a group header it selects every item in the group, or deselects
// them all if they were already all selected. Out-of-range cursors are ignored.
func toggleAtCursor(entries []entry, cursor int) {
	if cursor < 0 || cursor >= len(entries) {
		return
	}
	cur := entries[cursor]
	if !cur.isHeader {
		entries[cursor].selected = !entries[cursor].selected
		return
	}
	// Header: determine whether the whole group is already selected.
	allSelected := true
	for i := range entries {
		if !entries[i].isHeader && entries[i].groupIdx == cur.groupIdx && !entries[i].selected {
			allSelected = false
			break
		}
	}
	// Then flip the whole group in one direction.
	for i := range entries {
		if !entries[i].isHeader && entries[i].groupIdx == cur.groupIdx {
			entries[i].selected = !allSelected
		}
	}
}
// setAll marks every item row selected or deselected; headers are untouched.
func setAll(entries []entry, selected bool) {
	for i := range entries {
		if entries[i].isHeader {
			continue
		}
		entries[i].selected = selected
	}
}
// --- drawing ----------------------------------------------------------------

// Fixed screen chrome around the scrollable list area.
const (
	headerLines = 2 // title + blank line
	footerLines = 2 // blank line + keybinds
)

// Styles for every row state. The *Cur variants render the row under the
// cursor (underline/reverse) so the cursor is visible without a separate marker.
var (
	styleDefault    = tcell.StyleDefault
	styleTitle      = tcell.StyleDefault.Bold(true)
	styleGroup      = tcell.StyleDefault.Bold(true).Foreground(tcell.ColorTeal)
	styleGroupCur   = tcell.StyleDefault.Bold(true).Foreground(tcell.ColorTeal).Reverse(true)
	styleCheck      = tcell.StyleDefault.Foreground(tcell.ColorGreen).Bold(true)
	styleUncheck    = tcell.StyleDefault.Dim(true)
	styleItem       = tcell.StyleDefault
	styleItemCur    = tcell.StyleDefault.Bold(true).Underline(true)
	styleCheckCur   = tcell.StyleDefault.Foreground(tcell.ColorGreen).Bold(true).Underline(true)
	styleUncheckCur = tcell.StyleDefault.Dim(true).Underline(true)
	styleFooter     = tcell.StyleDefault.Dim(true)
)
// drawPicker renders one full frame: title with selection count, the visible
// slice of entry rows starting at offset, an optional scrollbar on the right
// edge, and the keybinding footer on the last screen row.
func drawPicker(
	screen tcell.Screen,
	entries []entry,
	groups []PickerGroup,
	cursor, offset, w, h, selectedCount, totalItems int,
) {
	screen.Clear()
	// Title
	title := fmt.Sprintf(" Select traces to open (%d/%d selected)", selectedCount, totalItems)
	drawLine(screen, 0, 0, w, title, styleTitle)
	// List area: everything between the header and footer chrome.
	listHeight := h - headerLines - footerLines
	if listHeight < 1 {
		listHeight = 1
	}
	for i := 0; i < listHeight; i++ {
		ei := offset + i
		if ei >= len(entries) {
			break
		}
		y := headerLines + i
		renderEntry(screen, entries, groups, ei, cursor, w, y)
	}
	// Scrollbar hint — only when the list doesn't fit on screen.
	if len(entries) > listHeight {
		drawScrollbar(screen, w-1, headerLines, listHeight, offset, len(entries))
	}
	// Footer
	footerY := h - 1
	footer := " ↑/↓ move space toggle a all n none enter open q/esc quit"
	drawLine(screen, 0, footerY, w, footer, styleFooter)
}
// renderEntry draws a single row at screen row y: either a group header with
// its "(selected/total)" summary, or an item row with a checkbox. The row
// under the cursor gets the highlighted style variants.
func renderEntry(screen tcell.Screen, entries []entry, groups []PickerGroup, ei, cursor, w, y int) {
	e := entries[ei]
	isCursor := ei == cursor
	if e.isHeader {
		// Count this group's selected/total to annotate the header.
		groupSelected := 0
		groupTotal := 0
		for _, e2 := range entries {
			if !e2.isHeader && e2.groupIdx == e.groupIdx {
				groupTotal++
				if e2.selected {
					groupSelected++
				}
			}
		}
		label := fmt.Sprintf(" %s (%d/%d)", e.label, groupSelected, groupTotal)
		style := styleGroup
		if isCursor {
			style = styleGroupCur
		}
		drawLine(screen, 0, y, w, label, style)
		return
	}
	// Item row: " [x] label" or " > [x] label"
	prefix := " "
	if isCursor {
		prefix = " > "
	}
	// Pick checkbox glyph and styles based on selection + cursor state.
	check := "[ ]"
	cStyle := styleUncheck
	iStyle := styleItem
	if isCursor {
		cStyle = styleUncheckCur
		iStyle = styleItemCur
	}
	if e.selected {
		check = "[x]"
		cStyle = styleCheck
		if isCursor {
			cStyle = styleCheckCur
		}
	}
	// Draw the three segments left to right, advancing x each time.
	x := drawStr(screen, 0, y, w, prefix, iStyle)
	x = drawStr(screen, x, y, w, check, cStyle)
	drawStr(screen, x, y, w, " "+e.label, iStyle)
}
// drawScrollbar draws a vertical scrollbar in column x spanning `height` rows
// from `top`. The thumb size and position are proportional to the visible
// window within `total` entries. No-op when everything already fits.
func drawScrollbar(screen tcell.Screen, x, top, height, offset, total int) {
	if total <= height || height < 1 {
		return
	}
	thumbSize := max(1, height*height/total)
	thumbStart := top + offset*height/total
	thumbEnd := thumbStart + thumbSize
	for y := top; y < top+height; y++ {
		if y >= thumbStart && y < thumbEnd {
			screen.SetContent(x, y, '┃', nil, styleDefault)
		} else {
			screen.SetContent(x, y, '│', nil, styleDefault.Dim(true))
		}
	}
}
// drawLine fills an entire row starting at x=startX, padding to width w.
// The padding carries the same style so row highlights extend to the edge.
func drawLine(screen tcell.Screen, startX, y, w int, s string, style tcell.Style) {
	end := drawStr(screen, startX, y, w, s, style)
	for col := end; col < w; col++ {
		screen.SetContent(col, y, ' ', nil, style)
	}
}
// drawStr writes a string at (x, y) up to maxX and returns the next x position.
// Iterates by rune, one cell per rune; text past maxX is clipped.
func drawStr(screen tcell.Screen, x, y, maxX int, s string, style tcell.Style) int {
	for _, r := range s {
		if x >= maxX {
			return x
		}
		screen.SetContent(x, y, r, nil, style)
		x++
	}
	return x
}

View File

@@ -0,0 +1,22 @@
import { cn } from "@opal/utils";
import type { IconProps } from "@opal/types";

// Bifrost brand icon. `size` sets both width and height; the hard-coded
// text color classes render it teal in light mode and white in dark mode,
// with the path filled via `currentColor`. Extra props pass through to <svg>.
const SvgBifrost = ({ size, className, ...props }: IconProps) => (
  <svg
    width={size}
    height={size}
    viewBox="0 0 37 46"
    fill="none"
    xmlns="http://www.w3.org/2000/svg"
    className={cn(className, "text-[#33C19E] dark:text-white")}
    {...props}
  >
    <title>Bifrost</title>
    <path
      d="M27.6219 46H0V36.8H27.6219V46ZM36.8268 36.8H27.6219V27.6H36.8268V36.8ZM18.4146 27.6H9.2073V18.4H18.4146V27.6ZM36.8268 18.4H27.6219V9.2H36.8268V18.4ZM27.6219 9.2H0V0H27.6219V9.2Z"
      fill="currentColor"
    />
  </svg>
);

export default SvgBifrost;

View File

@@ -24,6 +24,7 @@ export { default as SvgAzure } from "@opal/icons/azure";
export { default as SvgBarChart } from "@opal/icons/bar-chart";
export { default as SvgBarChartSmall } from "@opal/icons/bar-chart-small";
export { default as SvgBell } from "@opal/icons/bell";
export { default as SvgBifrost } from "@opal/icons/bifrost";
export { default as SvgBlocks } from "@opal/icons/blocks";
export { default as SvgBookOpen } from "@opal/icons/book-open";
export { default as SvgBookmark } from "@opal/icons/bookmark";

24
web/package-lock.json generated
View File

@@ -7901,7 +7901,9 @@
}
},
"node_modules/anymatch/node_modules/picomatch": {
"version": "2.3.1",
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
"license": "MIT",
"engines": {
"node": ">=8.6"
@@ -10701,7 +10703,9 @@
"license": "MIT"
},
"node_modules/handlebars": {
"version": "4.7.8",
"version": "4.7.9",
"resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.9.tgz",
"integrity": "sha512-4E71E0rpOaQuJR2A3xDZ+GM1HyWYv1clR58tC8emQNeQe3RH7MAzSbat+V0wG78LQBo6m6bzSG/L4pBuCsgnUQ==",
"dev": true,
"license": "MIT",
"dependencies": {
@@ -12555,7 +12559,9 @@
}
},
"node_modules/jest-util/node_modules/picomatch": {
"version": "2.3.1",
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
"dev": true,
"license": "MIT",
"engines": {
@@ -13881,7 +13887,9 @@
}
},
"node_modules/micromatch/node_modules/picomatch": {
"version": "2.3.1",
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
"license": "MIT",
"engines": {
"node": ">=8.6"
@@ -15001,7 +15009,9 @@
"license": "ISC"
},
"node_modules/picomatch": {
"version": "4.0.3",
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
"license": "MIT",
"engines": {
"node": ">=12"
@@ -15889,7 +15899,9 @@
}
},
"node_modules/readdirp/node_modules/picomatch": {
"version": "2.3.1",
"version": "2.3.2",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
"license": "MIT",
"engines": {
"node": ">=8.6"

View File

@@ -30,8 +30,11 @@ import CreateButton from "@/refresh-components/buttons/CreateButton";
import { Button } from "@opal/components";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import Text from "@/refresh-components/texts/Text";
import { SvgEdit, SvgInfo, SvgKey, SvgRefreshCw } from "@opal/icons";
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
import Message from "@/refresh-components/messages/Message";
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
import { useBillingInformation } from "@/hooks/useBillingInformation";
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
const route = ADMIN_ROUTES.API_KEYS;
@@ -44,6 +47,11 @@ function Main() {
} = useSWR<APIKey[]>("/api/admin/api-key", errorHandlingFetcher);
const canCreateKeys = useCloudSubscription();
const { data: billingData } = useBillingInformation();
const isTrialing =
billingData !== undefined &&
hasActiveSubscription(billingData) &&
billingData.status === BillingStatus.TRIALING;
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
const [keyIsGenerating, setKeyIsGenerating] = useState(false);
@@ -75,6 +83,16 @@ function Main() {
const introSection = (
<div className="flex flex-col items-start gap-4">
{isTrialing && (
<Message
static
warning
close={false}
className="w-full"
text="Upgrade to a paid plan to create API keys."
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
/>
)}
<Text as="p">
API Keys allow you to access Onyx APIs programmatically.
{canCreateKeys
@@ -85,23 +103,9 @@ function Main() {
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
Create API Key
</CreateButton>
) : (
<div className="flex flex-col gap-2 rounded-lg bg-background-tint-02 p-4">
<div className="flex items-center gap-1.5">
<Text as="p" text04>
Upgrade to a paid plan to create API keys.
</Text>
<Button
variant="none"
prominence="tertiary"
size="2xs"
icon={SvgInfo}
tooltip="API keys enable programmatic access to Onyx for service accounts and integrations. Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
/>
</div>
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
</div>
)}
) : isTrialing ? (
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
) : null}
</div>
);

View File

@@ -0,0 +1,387 @@
/**
* Tests for BillingPage handleBillingReturn retry logic.
*
* The retry logic retries claimLicense up to 3 times with 2s backoff
* when returning from a Stripe checkout session. This prevents the user
* from getting stranded when the Stripe webhook fires concurrently with
* the browser redirect and the license isn't ready yet.
*/
import React from "react";
import { render, screen, waitFor } from "@tests/setup/test-utils";
import { act } from "@testing-library/react";
// ---- Stable mock objects (must be named with mock* prefix for jest hoisting) ----
// useRouter and useSearchParams must return the SAME reference each call, otherwise
// React's useEffect sees them as changed and re-runs the effect on every render.
const mockRouter = {
replace: jest.fn() as jest.Mock,
refresh: jest.fn() as jest.Mock,
};
const mockSearchParams = {
get: jest.fn() as jest.Mock,
};
const mockClaimLicense = jest.fn() as jest.Mock;
const mockRefreshBilling = jest.fn() as jest.Mock;
const mockRefreshLicense = jest.fn() as jest.Mock;
// ---- Mocks ----
jest.mock("next/navigation", () => ({
useRouter: () => mockRouter,
useSearchParams: () => mockSearchParams,
}));
jest.mock("@/layouts/settings-layouts", () => ({
Root: ({ children }: { children: React.ReactNode }) => (
<div data-testid="settings-root">{children}</div>
),
Header: () => <div data-testid="settings-header" />,
Body: ({ children }: { children: React.ReactNode }) => (
<div data-testid="settings-body">{children}</div>
),
}));
jest.mock("@/layouts/general-layouts", () => ({
Section: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
jest.mock("@opal/icons", () => ({
SvgArrowUpCircle: () => <svg />,
SvgWallet: () => <svg />,
}));
jest.mock("./PlansView", () => ({
__esModule: true,
default: () => <div data-testid="plans-view" />,
}));
jest.mock("./CheckoutView", () => ({
__esModule: true,
default: () => <div data-testid="checkout-view" />,
}));
jest.mock("./BillingDetailsView", () => ({
__esModule: true,
default: () => <div data-testid="billing-details-view" />,
}));
jest.mock("./LicenseActivationCard", () => ({
__esModule: true,
default: () => <div data-testid="license-activation-card" />,
}));
jest.mock("@/refresh-components/messages/Message", () => ({
__esModule: true,
default: ({
text,
description,
onClose,
}: {
text: string;
description?: string;
onClose?: () => void;
}) => (
<div data-testid="activating-banner">
<span data-testid="activating-banner-text">{text}</span>
{description && (
<span data-testid="activating-banner-description">{description}</span>
)}
{onClose && (
<button data-testid="activating-banner-close" onClick={onClose}>
Close
</button>
)}
</div>
),
}));
jest.mock("@/lib/billing", () => ({
useBillingInformation: jest.fn(),
useLicense: jest.fn(),
hasActiveSubscription: jest.fn().mockReturnValue(false),
claimLicense: (...args: unknown[]) => mockClaimLicense(...args),
}));
jest.mock("@/lib/constants", () => ({
NEXT_PUBLIC_CLOUD_ENABLED: false,
}));
// ---- Import after mocks ----
import BillingPage from "./page";
import { useBillingInformation, useLicense } from "@/lib/billing";
// ---- Test helpers ----
/**
 * Installs default return values for the billing/license hooks used by
 * BillingPage: no data, not loading, no error, and the shared stable
 * refresh mocks so call counts can be asserted across renders.
 */
function setupHooks() {
  const billingHookValue = {
    data: null,
    isLoading: false,
    error: null,
    refresh: mockRefreshBilling,
  };
  const licenseHookValue = {
    data: null,
    isLoading: false,
    refresh: mockRefreshLicense,
  };
  (useBillingInformation as jest.Mock).mockReturnValue(billingHookValue);
  (useLicense as jest.Mock).mockReturnValue(licenseHookValue);
}
// ---- Tests ----
// Each test renders BillingPage with fake timers and drives the retry /
// polling logic purely through jest timer advancement — no real waiting.
describe("BillingPage — handleBillingReturn retry logic", () => {
beforeEach(() => {
jest.clearAllMocks();
jest.useFakeTimers();
setupHooks();
// Default: no billing-return params
mockSearchParams.get.mockReturnValue(null);
// Clear any activating state from prior tests
sessionStorage.clear();
});
afterEach(() => {
jest.useRealTimers();
jest.restoreAllMocks();
});
// ---- Checkout-return retry behavior ----
test("calls claimLicense once and refreshes on first-attempt success", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_test_123" : null
);
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
// Flush the 2s backoff timers (none should fire on immediate success)
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
expect(mockClaimLicense).toHaveBeenCalledWith("cs_test_123");
});
expect(mockRouter.refresh).toHaveBeenCalled();
expect(mockRefreshBilling).toHaveBeenCalled();
// URL cleaned up after checkout return
expect(mockRouter.replace).toHaveBeenCalledWith("/admin/billing", {
scroll: false,
});
});
test("retries after first failure and succeeds on second attempt", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_retry_test" : null
);
// First claim rejects (webhook race), second resolves
mockClaimLicense
.mockRejectedValueOnce(new Error("License not ready yet"))
.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(2);
});
// On eventual success, router and billing should be refreshed
expect(mockRouter.refresh).toHaveBeenCalled();
expect(mockRefreshBilling).toHaveBeenCalled();
});
test("retries all 3 times then navigates to details even on total failure", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_all_fail" : null
);
// All 3 attempts fail
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
// Silence the expected console.error so test output stays clean
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(3);
});
// User stays on plans view with the activating banner
await waitFor(() => {
expect(screen.getByTestId("plans-view")).toBeInTheDocument();
});
// refreshBilling still fires so billing state is up to date
expect(mockRefreshBilling).toHaveBeenCalled();
// Failure is logged
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining("Failed to sync license after billing return"),
expect.any(Error)
);
consoleSpy.mockRestore();
});
test("calls claimLicense without session_id on portal_return", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "portal_return" ? "true" : null
);
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
// No session_id for portal returns — called with undefined
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
});
expect(mockRefreshBilling).toHaveBeenCalled();
});
test("does not call claimLicense when no billing-return params present", async () => {
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(mockClaimLicense).not.toHaveBeenCalled();
});
// ---- Activating-banner behavior (sessionStorage-backed) ----
test("shows activating banner and sets sessionStorage on 3x retry failure", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_all_fail" : null
);
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
// Silence the expected console.error so test output stays clean
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
});
expect(screen.getByTestId("activating-banner-text")).toHaveTextContent(
"Your license is still activating"
);
expect(
sessionStorage.getItem("billing_license_activating_until")
).not.toBeNull();
consoleSpy.mockRestore();
});
test("banner not rendered when no activating state", async () => {
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
});
test("banner shown on mount when sessionStorage key is set and not expired", async () => {
// Expiry 2 minutes in the future → banner should appear on mount
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
// Flush React effects — banner is visible from lazy state init, no timer advancement needed
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
});
test("banner not shown on mount when sessionStorage key is expired", async () => {
// Expiry already in the past → key should be cleared, no banner
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() - 1000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
});
test("poll calls claimLicense after 15s and clears banner on success", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
// Poll attempt succeeds
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
// Flush effects — banner visible from lazy state init
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
// Advance past one poll interval (15s)
await act(async () => {
await jest.advanceTimersByTimeAsync(15_000);
});
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
expect(mockRefreshBilling).toHaveBeenCalled();
expect(mockRefreshLicense).toHaveBeenCalled();
expect(mockRouter.refresh).toHaveBeenCalled();
});
test("close button removes banner and clears sessionStorage", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
// Flush effects — banner visible from lazy state init
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
const closeButton = screen.getByTestId("activating-banner-close");
await act(async () => {
closeButton.click();
});
// Dismissal clears both the UI state and the persisted expiry key
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
});
});

View File

@@ -17,6 +17,7 @@ import {
} from "@/lib/billing";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { useUser } from "@/providers/UserProvider";
import Message from "@/refresh-components/messages/Message";
import PlansView from "./PlansView";
import CheckoutView from "./CheckoutView";
@@ -24,6 +25,9 @@ import BillingDetailsView from "./BillingDetailsView";
import LicenseActivationCard from "./LicenseActivationCard";
import "./billing.css";
// sessionStorage key: value is a unix-ms expiry timestamp
const BILLING_ACTIVATING_KEY = "billing_license_activating_until";
// ----------------------------------------------------------------------------
// Types
// ----------------------------------------------------------------------------
@@ -105,6 +109,7 @@ export default function BillingPage() {
const [transitionType, setTransitionType] = useState<
"expand" | "collapse" | "fade"
>("fade");
const [isActivating, setIsActivating] = useState<boolean>(false);
const {
data: billingData,
@@ -155,6 +160,17 @@ export default function BillingPage() {
view,
]);
// Read activating state from sessionStorage after mount (avoids SSR hydration mismatch)
useEffect(() => {
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
if (!raw) return;
if (Number(raw) > Date.now()) {
setIsActivating(true);
} else {
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
}
}, []);
// Show license activation card when there's a Stripe error
useEffect(() => {
if (hasStripeError && !showLicenseActivationInput) {
@@ -172,24 +188,96 @@ export default function BillingPage() {
router.replace("/admin/billing", { scroll: false });
let cancelled = false;
const handleBillingReturn = async () => {
if (!NEXT_PUBLIC_CLOUD_ENABLED) {
try {
// After checkout, exchange session_id for license; after portal, re-sync license
await claimLicense(sessionId ?? undefined);
refreshLicense();
// Refresh the page to update settings (including ee_features_enabled)
router.refresh();
// Navigate to billing details now that the license is active
changeView("details");
} catch (error) {
console.error("Failed to sync license after billing return:", error);
// Retry up to 3 times with 2s backoff. The license may not be available
// immediately if the Stripe webhook hasn't finished processing yet
// (redirect and webhook fire nearly simultaneously).
let lastError: Error | null = null;
for (let attempt = 0; attempt < 3; attempt++) {
if (cancelled) return;
try {
// After checkout, exchange session_id for license; after portal, re-sync license
await claimLicense(sessionId ?? undefined);
if (cancelled) return;
refreshLicense();
// Refresh the page to update settings (including ee_features_enabled)
router.refresh();
// Navigate to billing details now that the license is active
changeView("details");
lastError = null;
break;
} catch (err) {
lastError = err instanceof Error ? err : new Error("Unknown error");
if (attempt < 2) {
await new Promise((resolve) => setTimeout(resolve, 2000));
}
}
}
if (cancelled) return;
if (lastError) {
console.error(
"Failed to sync license after billing return:",
lastError
);
// Show an activating banner on the plans view and keep retrying in the background.
sessionStorage.setItem(
BILLING_ACTIVATING_KEY,
String(Date.now() + 120_000)
);
setIsActivating(true);
changeView("plans");
}
}
refreshBilling();
if (!cancelled) refreshBilling();
};
handleBillingReturn();
}, [searchParams, router, refreshBilling, refreshLicense]);
return () => {
cancelled = true;
};
// changeView intentionally omitted: it only calls stable state setters and the
// effect runs at most once (when session_id/portal_return params are present).
}, [searchParams, router, refreshBilling, refreshLicense]); // eslint-disable-line react-hooks/exhaustive-deps
// Poll every 15s while activating, up to 2 minutes, to detect when the license arrives.
// The window is bounded by the expiry timestamp stored under BILLING_ACTIVATING_KEY
// (set to now + 120s when retries are exhausted), not by counting ticks here.
useEffect(() => {
if (!isActivating) return;
// Guards against overlapping claims if one request outlives the 15s interval
let requestInFlight = false;
const intervalId = setInterval(async () => {
if (requestInFlight) return;
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
if (!raw || Number(raw) <= Date.now()) {
// Expired — stop immediately without waiting for React cleanup
clearInterval(intervalId);
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
return;
}
requestInFlight = true;
try {
// No session_id here: by poll time the checkout session has already been
// consumed (or was never present), so this is a plain license re-sync.
await claimLicense(undefined);
// Success: clear the persisted activating window, refresh all billing
// state, and move the user to the details view.
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
refreshLicense();
refreshBilling();
router.refresh();
changeView("details");
} catch (err) {
// License not ready yet — keep polling. Log so unexpected failures
// (network errors, 500s) are distinguishable from expected 404s.
console.debug("License activation poll: will retry", err);
} finally {
requestInFlight = false;
}
}, 15_000);
return () => clearInterval(intervalId);
// Only isActivating should retrigger this effect; the other referenced values
// (router, refresh callbacks, changeView) are assumed stable — lint disabled.
}, [isActivating]); // eslint-disable-line react-hooks/exhaustive-deps
const handleRefresh = async () => {
await Promise.all([
@@ -386,6 +474,22 @@ export default function BillingPage() {
/>
<SettingsLayouts.Body>
<div className="flex flex-col items-center gap-6">
{isActivating && (
<Message
static
warning
large
text="Your license is still activating"
description="Your license is being processed. You'll be taken to billing details automatically once confirmed."
icon
close
onClose={() => {
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
}}
className="w-full"
/>
)}
{renderContent()}
{renderFooter()}
</div>

View File

@@ -1,11 +1,12 @@
"use client";
import { useState, useMemo } from "react";
import { useState, useMemo, useEffect } from "react";
import useSWR from "swr";
import Text from "@/refresh-components/texts/Text";
import { Select } from "@/refresh-components/cards";
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
import {
@@ -17,9 +18,16 @@ import {
ImageGenerationConfigView,
setDefaultImageGenerationConfig,
unsetDefaultImageGenerationConfig,
deleteImageGenerationConfig,
} from "@/lib/configuration/imageConfigurationService";
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
import Message from "@/refresh-components/messages/Message";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { Button } from "@opal/components";
import { SvgSlash, SvgUnplug } from "@opal/icons";
const NO_DEFAULT_VALUE = "__none__";
export default function ImageGenerationContent() {
const {
@@ -47,6 +55,11 @@ export default function ImageGenerationContent() {
);
const [editConfig, setEditConfig] =
useState<ImageGenerationConfigView | null>(null);
const [disconnectProvider, setDisconnectProvider] =
useState<ImageProvider | null>(null);
const [replacementProviderId, setReplacementProviderId] = useState<
string | null
>(null);
const connectedProviderIds = useMemo(() => {
return new Set(configs.map((c) => c.image_provider_id));
@@ -115,6 +128,29 @@ export default function ImageGenerationContent() {
modal.toggle(true);
};
// Disconnects the provider selected in the confirmation modal.
// Optionally promotes a replacement to default first, then deletes the
// target's config and refetches both configs and providers. The modal
// state is always reset in `finally`, whether the calls succeed or fail.
const handleDisconnect = async () => {
if (!disconnectProvider) return;
try {
// If a replacement was selected (not "No Default"), activate it first
// so image generation never transitions through a defaultless state
// unless the user explicitly chose that.
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
await setDefaultImageGenerationConfig(replacementProviderId);
}
await deleteImageGenerationConfig(disconnectProvider.image_provider_id);
toast.success(`${disconnectProvider.title} disconnected`);
refetchConfigs();
refetchProviders();
} catch (error) {
// Surface a readable message; fall back to a generic one for non-Error throws
console.error("Failed to disconnect image generation provider:", error);
toast.error(
error instanceof Error ? error.message : "Failed to disconnect"
);
} finally {
// Close the modal and clear the replacement selection in all cases
setDisconnectProvider(null);
setReplacementProviderId(null);
}
};
const handleModalSuccess = () => {
toast.success("Provider configured successfully");
setEditConfig(null);
@@ -130,6 +166,36 @@ export default function ImageGenerationContent() {
);
}
// Compute replacement options when disconnecting an active provider
// True only when the provider being disconnected is the current default,
// in which case the modal must offer a replacement (or "No Default").
const isDisconnectingDefault =
disconnectProvider &&
defaultConfig?.image_provider_id === disconnectProvider.image_provider_id;
// Group connected replacement models by provider (excluding the model being disconnected)
const replacementGroups = useMemo(() => {
if (!disconnectProvider) return [];
return IMAGE_PROVIDER_GROUPS.map((group) => ({
...group,
providers: group.providers.filter(
(p) =>
p.image_provider_id !== disconnectProvider.image_provider_id &&
connectedProviderIds.has(p.image_provider_id)
),
// Drop groups left empty after filtering so the select shows no bare headers
})).filter((g) => g.providers.length > 0);
}, [disconnectProvider, connectedProviderIds]);
const needsReplacement = !!isDisconnectingDefault;
const hasReplacements = replacementGroups.length > 0;
// Auto-select first replacement when modal opens
// (keyed on disconnectProvider only, so re-renders while the modal is open
// don't clobber a selection the user has already made — lint disabled)
useEffect(() => {
if (needsReplacement && !replacementProviderId && hasReplacements) {
const firstGroup = replacementGroups[0];
const firstModel = firstGroup?.providers[0];
if (firstModel) setReplacementProviderId(firstModel.image_provider_id);
}
}, [disconnectProvider]); // eslint-disable-line react-hooks/exhaustive-deps
return (
<>
<div className="flex flex-col gap-6">
@@ -175,6 +241,11 @@ export default function ImageGenerationContent() {
onSelect={() => handleSelect(provider)}
onDeselect={() => handleDeselect(provider)}
onEdit={() => handleEdit(provider)}
onDisconnect={
getStatus(provider) !== "disconnected"
? () => setDisconnectProvider(provider)
: undefined
}
/>
))}
</div>
@@ -182,6 +253,105 @@ export default function ImageGenerationContent() {
))}
</div>
{disconnectProvider && (
<ConfirmationModalLayout
icon={SvgUnplug}
title={`Disconnect ${disconnectProvider.title}`}
description="This will remove the stored credentials for this provider."
onClose={() => {
setDisconnectProvider(null);
setReplacementProviderId(null);
}}
submit={
<Button
variant="danger"
onClick={() => void handleDisconnect()}
disabled={
needsReplacement && hasReplacements && !replacementProviderId
}
>
Disconnect
</Button>
}
>
{needsReplacement ? (
hasReplacements ? (
<Section alignItems="start">
<Text as="p" text03>
<b>{disconnectProvider.title}</b> is currently the default
image generation model. Session history will be preserved.
</Text>
<Section alignItems="start" gap={0.25}>
<Text as="p" text04>
Set New Default
</Text>
<InputSelect
value={replacementProviderId ?? undefined}
onValueChange={(v) => setReplacementProviderId(v)}
>
<InputSelect.Trigger placeholder="Select a replacement model" />
<InputSelect.Content>
{replacementGroups.map((group) => (
<InputSelect.Group key={group.name}>
<InputSelect.Label>{group.name}</InputSelect.Label>
{group.providers.map((p) => (
<InputSelect.Item
key={p.image_provider_id}
value={p.image_provider_id}
icon={() => (
<ProviderIcon
provider={p.provider_name}
size={16}
/>
)}
>
{p.title}
</InputSelect.Item>
))}
</InputSelect.Group>
))}
<InputSelect.Separator />
<InputSelect.Item
value={NO_DEFAULT_VALUE}
icon={SvgSlash}
>
<span>
<b>No Default</b>
<span className="text-text-03">
{" "}
(Disable Image Generation)
</span>
</span>
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</Section>
</Section>
) : (
<>
<Text as="p" text03>
<b>{disconnectProvider.title}</b> is currently the default
image generation model.
</Text>
<Text as="p" text03>
Connect another provider to continue using image generation.
</Text>
</>
)
) : (
<>
<Text as="p" text03>
<b>{disconnectProvider.title}</b> models will no longer be used
to generate images.
</Text>
<Text as="p" text03>
Session history will be preserved.
</Text>
</>
)}
</ConfirmationModalLayout>
)}
{activeProvider && (
<modal.Provider>
<ImageGenerationConnectionModal

View File

@@ -23,6 +23,7 @@ import {
BedrockModelResponse,
LMStudioModelResponse,
LiteLLMProxyModelResponse,
BifrostModelResponse,
ModelConfiguration,
LLMProviderName,
BedrockFetchParams,
@@ -30,8 +31,9 @@ import {
LMStudioFetchParams,
OpenRouterFetchParams,
LiteLLMProxyFetchParams,
BifrostFetchParams,
} from "@/interfaces/llm";
import { SvgAws, SvgOpenrouter } from "@opal/icons";
import { SvgAws, SvgBifrost, SvgOpenrouter } from "@opal/icons";
// Aggregator providers that host models from multiple vendors
export const AGGREGATOR_PROVIDERS = new Set([
@@ -41,6 +43,7 @@ export const AGGREGATOR_PROVIDERS = new Set([
"ollama_chat",
"lm_studio",
"litellm_proxy",
"bifrost",
"vertex_ai",
]);
@@ -78,6 +81,7 @@ export const getProviderIcon = (
bedrock_converse: SvgAws,
openrouter: SvgOpenrouter,
litellm_proxy: LiteLLMIcon,
bifrost: SvgBifrost,
vertex_ai: GeminiIcon,
};
@@ -263,8 +267,11 @@ export const fetchOpenRouterModels = async (
try {
const errorData = await response.json();
errorMessage = errorData.detail || errorData.message || errorMessage;
} catch {
// ignore JSON parsing errors
} catch (jsonError) {
console.warn(
"Failed to parse OpenRouter model fetch error response",
jsonError
);
}
return { models: [], error: errorMessage };
}
@@ -319,8 +326,11 @@ export const fetchLMStudioModels = async (
try {
const errorData = await response.json();
errorMessage = errorData.detail || errorData.message || errorMessage;
} catch {
// ignore JSON parsing errors
} catch (jsonError) {
console.warn(
"Failed to parse LM Studio model fetch error response",
jsonError
);
}
return { models: [], error: errorMessage };
}
@@ -343,6 +353,64 @@ export const fetchLMStudioModels = async (
}
};
/**
* Fetches Bifrost models directly without any form state dependencies.
* Uses snake_case params to match API structure.
*/
/**
 * Fetches Bifrost models directly without any form state dependencies.
 * Uses snake_case params to match API structure.
 *
 * Returns `{ models }` on success, or `{ models: [], error }` when the
 * API base is missing, the endpoint responds with a non-OK status, or
 * the request itself throws (network failure / abort).
 */
export const fetchBifrostModels = async (
  params: BifrostFetchParams
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
  // An API base is mandatory — bail out before issuing any request.
  if (!params.api_base) {
    return { models: [], error: "API Base is required" };
  }

  try {
    const response = await fetch("/api/admin/llm/bifrost/available-models", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        api_base: params.api_base,
        api_key: params.api_key,
        provider_name: params.provider_name,
      }),
      // Allows the caller to cancel an in-flight fetch (e.g. on form change)
      signal: params.signal,
    });

    if (!response.ok) {
      // Prefer the server-provided detail/message; fall back to a generic error.
      let errorMessage = "Failed to fetch models";
      try {
        const errorData = await response.json();
        errorMessage = errorData.detail || errorData.message || errorMessage;
      } catch (jsonError) {
        console.warn(
          "Failed to parse Bifrost model fetch error response",
          jsonError
        );
      }
      return { models: [], error: errorMessage };
    }

    // Map the API payload into the shared ModelConfiguration shape,
    // defaulting every fetched model to visible.
    const payload: BifrostModelResponse[] = await response.json();
    return {
      models: payload.map((modelData) => ({
        name: modelData.name,
        display_name: modelData.display_name,
        is_visible: true,
        max_input_tokens: modelData.max_input_tokens,
        supports_image_input: modelData.supports_image_input,
        supports_reasoning: modelData.supports_reasoning,
      })),
    };
  } catch (error) {
    return {
      models: [],
      error: error instanceof Error ? error.message : "Unknown error",
    };
  }
};
/**
* Fetches LiteLLM Proxy models directly without any form state dependencies.
* Uses snake_case params to match API structure.
@@ -456,6 +524,13 @@ export const fetchModels = async (
provider_name: formValues.name,
signal,
});
case LLMProviderName.BIFROST:
return fetchBifrostModels({
api_base: formValues.api_base,
api_key: formValues.api_key,
provider_name: formValues.name,
signal,
});
default:
return { models: [], error: `Unknown provider: ${providerName}` };
}
@@ -469,6 +544,7 @@ export function canProviderFetchModels(providerName?: string) {
case LLMProviderName.LM_STUDIO:
case LLMProviderName.OPENROUTER:
case LLMProviderName.LITELLM_PROXY:
case LLMProviderName.BIFROST:
return true;
default:
return false;

View File

@@ -1,32 +1,25 @@
"use client";
import Image from "next/image";
import { useMemo, useState, useReducer } from "react";
import { useEffect, useMemo, useState, useReducer } from "react";
import { InfoIcon } from "@/components/icons/icons";
import Text from "@/refresh-components/texts/Text";
import { Select } from "@/refresh-components/cards";
import { Section } from "@/layouts/general-layouts";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { Content } from "@opal/layouts";
import useSWR from "swr";
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
import { ThreeDotsLoader } from "@/components/Loading";
import { Callout } from "@/components/ui/callout";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
import { Disabled } from "@opal/core";
import { cn } from "@/lib/utils";
import {
SvgArrowExchange,
SvgArrowRightCircle,
SvgCheckSquare,
SvgEdit,
SvgGlobe,
SvgOnyxLogo,
SvgX,
} from "@opal/icons";
import { toast } from "@/hooks/useToast";
import { SvgGlobe, SvgOnyxLogo, SvgSlash, SvgUnplug } from "@opal/icons";
import { Button as OpalButton } from "@opal/components";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
const route = ADMIN_ROUTES.WEB_SEARCH;
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import {
SEARCH_PROVIDERS_URL,
SEARCH_PROVIDER_DETAILS,
@@ -58,6 +51,10 @@ import {
} from "@/app/admin/configuration/web-search/WebProviderModalReducer";
import { connectProviderFlow } from "@/app/admin/configuration/web-search/connectProviderFlow";
const NO_DEFAULT_VALUE = "__none__";
const route = ADMIN_ROUTES.WEB_SEARCH;
interface WebSearchProviderView {
id: number;
name: string;
@@ -76,27 +73,151 @@ interface WebContentProviderView {
has_api_key: boolean;
}
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
isHovered: boolean;
onMouseEnter: () => void;
onMouseLeave: () => void;
children: React.ReactNode;
interface DisconnectTargetState {
id: number;
label: string;
category: "search" | "content";
providerType: string;
}
function HoverIconButton({
isHovered,
onMouseEnter,
onMouseLeave,
children,
...buttonProps
}: HoverIconButtonProps) {
function WebSearchDisconnectModal({
disconnectTarget,
searchProviders,
contentProviders,
replacementProviderId,
onReplacementChange,
onClose,
onDisconnect,
}: {
disconnectTarget: DisconnectTargetState;
searchProviders: WebSearchProviderView[];
contentProviders: WebContentProviderView[];
replacementProviderId: string | null;
onReplacementChange: (id: string | null) => void;
onClose: () => void;
onDisconnect: () => void;
}) {
const isSearch = disconnectTarget.category === "search";
// Determine if the target is currently the active/selected provider
const isActive = isSearch
? searchProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
false
: contentProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
false;
// Find other configured providers as replacements
const replacementOptions = isSearch
? searchProviders.filter(
(p) => p.id !== disconnectTarget.id && p.id > 0 && p.has_api_key
)
: contentProviders.filter(
(p) =>
p.id !== disconnectTarget.id &&
p.provider_type !== "onyx_web_crawler" &&
p.id > 0 &&
p.has_api_key
);
const needsReplacement = isActive;
const hasReplacements = replacementOptions.length > 0;
const getLabel = (p: { name: string; provider_type: string }) => {
if (isSearch) {
const details =
SEARCH_PROVIDER_DETAILS[p.provider_type as WebSearchProviderType];
return details?.label ?? p.name ?? p.provider_type;
}
const details = CONTENT_PROVIDER_DETAILS[p.provider_type];
return details?.label ?? p.name ?? p.provider_type;
};
const categoryLabel = isSearch ? "search engine" : "web crawler";
const featureLabel = isSearch ? "web search" : "web crawling";
const disableLabel = isSearch ? "Disable Web Search" : "Disable Web Crawling";
// Auto-select first replacement when modal opens
useEffect(() => {
if (needsReplacement && hasReplacements && !replacementProviderId) {
const first = replacementOptions[0];
if (first) onReplacementChange(String(first.id));
}
}, []); // eslint-disable-line react-hooks/exhaustive-deps
return (
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
{/* TODO(@raunakab): migrate to opal Button once HoverIconButtonProps typing is resolved */}
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
{children}
</Button>
</div>
<ConfirmationModalLayout
icon={SvgUnplug}
title={`Disconnect ${disconnectTarget.label}`}
description="This will remove the stored credentials for this provider."
onClose={onClose}
submit={
<OpalButton
variant="danger"
onClick={onDisconnect}
disabled={
needsReplacement && hasReplacements && !replacementProviderId
}
>
Disconnect
</OpalButton>
}
>
{needsReplacement ? (
hasReplacements ? (
<Section alignItems="start">
<Text as="p" text03>
<b>{disconnectTarget.label}</b> is currently the active{" "}
{categoryLabel}. Search history will be preserved.
</Text>
<Section alignItems="start" gap={0.25}>
<Text as="p" secondaryBody text03>
Set New Default
</Text>
<InputSelect
value={replacementProviderId ?? undefined}
onValueChange={(v) => onReplacementChange(v)}
>
<InputSelect.Trigger placeholder="Select a replacement provider" />
<InputSelect.Content>
{replacementOptions.map((p) => (
<InputSelect.Item key={p.id} value={String(p.id)}>
{getLabel(p)}
</InputSelect.Item>
))}
<InputSelect.Separator />
<InputSelect.Item value={NO_DEFAULT_VALUE} icon={SvgSlash}>
<span>
<b>No Default</b>
<span className="text-text-03"> ({disableLabel})</span>
</span>
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</Section>
</Section>
) : (
<>
<Text as="p" text03>
<b>{disconnectTarget.label}</b> is currently the active{" "}
{categoryLabel}.
</Text>
<Text as="p" text03>
Connect another provider to continue using {featureLabel}.
</Text>
</>
)
) : (
<>
<Text as="p" text03>
{isSearch ? "Web search" : "Web crawling"} will no longer be routed
through <b>{disconnectTarget.label}</b>.
</Text>
<Text as="p" text03>
Search history will be preserved.
</Text>
</>
)}
</ConfirmationModalLayout>
);
}
@@ -105,6 +226,11 @@ export default function Page() {
WebProviderModalReducer,
initialWebProviderModalState
);
const [disconnectTarget, setDisconnectTarget] =
useState<DisconnectTargetState | null>(null);
const [replacementProviderId, setReplacementProviderId] = useState<
string | null
>(null);
const [contentModal, dispatchContentModal] = useReducer(
WebProviderModalReducer,
initialWebProviderModalState
@@ -113,8 +239,6 @@ export default function Page() {
const [contentActivationError, setContentActivationError] = useState<
string | null
>(null);
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
const {
data: searchProvidersData,
error: searchProvidersError,
@@ -833,6 +957,67 @@ export default function Page() {
});
};
/**
 * Disconnects the currently targeted web-search/content provider.
 *
 * If the admin selected a replacement (anything other than the "No Default"
 * sentinel), the replacement is activated first so the feature never ends up
 * without an active provider. On failure, the error message is surfaced in
 * the banner matching the target's category; the modal state is always reset.
 */
const handleDisconnectProvider = async () => {
  if (!disconnectTarget) return;
  const { id, category } = disconnectTarget;
  try {
    // If a replacement was selected (not "No Default"), activate it first.
    if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
      const repId = Number(replacementProviderId);
      const activateEndpoint =
        category === "search"
          ? `/api/admin/web-search/search-providers/${repId}/activate`
          : `/api/admin/web-search/content-providers/${repId}/activate`;
      const activateResp = await fetch(activateEndpoint, {
        method: "POST",
        headers: { "Content-Type": "application/json" },
      });
      if (!activateResp.ok) {
        const errorBody = await activateResp.json().catch(() => ({}));
        throw new Error(
          typeof errorBody?.detail === "string"
            ? errorBody.detail
            : "Failed to activate replacement provider."
        );
      }
    }

    const response = await fetch(
      `/api/admin/web-search/${category}-providers/${id}`,
      { method: "DELETE" }
    );
    if (!response.ok) {
      const errorBody = await response.json().catch((parseErr) => {
        console.error("Failed to parse disconnect error response:", parseErr);
        return {};
      });
      throw new Error(
        typeof errorBody?.detail === "string"
          ? errorBody.detail
          : "Failed to disconnect provider."
      );
    }

    toast.success(`${disconnectTarget.label} disconnected`);
    // The two provider lists are independent SWR caches — refresh in parallel
    // instead of serially awaiting each revalidation.
    await Promise.all([mutateSearchProviders(), mutateContentProviders()]);
  } catch (error) {
    console.error("Failed to disconnect web search provider:", error);
    const message =
      error instanceof Error ? error.message : "Unexpected error occurred.";
    // Route the error to the banner for the affected provider category.
    if (category === "search") {
      setActivationError(message);
    } else {
      setContentActivationError(message);
    }
  } finally {
    // Always close the modal and clear the replacement selection.
    setDisconnectTarget(null);
    setReplacementProviderId(null);
  }
};
return (
<>
<SettingsLayouts.Root>
@@ -894,149 +1079,79 @@ export default function Page() {
provider
);
const isActive = provider?.is_active ?? false;
const isHighlighted = isActive;
const providerId = provider?.id;
const canOpenModal =
isBuiltInSearchProviderType(providerType);
const buttonState = (() => {
if (!provider || !isConfigured) {
return {
label: "Connect",
disabled: false,
icon: "arrow" as const,
onClick: canOpenModal
const status: "disconnected" | "connected" | "selected" =
!isConfigured
? "disconnected"
: isActive
? "selected"
: "connected";
return (
<Select
key={`${key}-${providerType}`}
icon={() =>
logoSrc ? (
<Image
src={logoSrc}
alt={`${label} logo`}
width={16}
height={16}
/>
) : (
<SvgGlobe size={16} />
)
}
title={label}
description={subtitle}
status={status}
onConnect={
canOpenModal
? () => {
openSearchModal(providerType, provider);
setActivationError(null);
}
: undefined,
};
}
if (isActive) {
return {
label: "Current Default",
disabled: false,
icon: "check" as const,
onClick: providerId
: undefined
}
onSelect={
providerId
? () => {
void handleActivateSearchProvider(providerId);
}
: undefined
}
onDeselect={
providerId
? () => {
void handleDeactivateSearchProvider(providerId);
}
: undefined,
};
}
return {
label: "Set as Default",
disabled: false,
icon: "arrow-circle" as const,
onClick: providerId
? () => {
void handleActivateSearchProvider(providerId);
}
: undefined,
};
})();
const buttonKey = `search-${key}-${providerType}`;
const isButtonHovered = hoveredButtonKey === buttonKey;
const isCardClickable =
buttonState.icon === "arrow" &&
typeof buttonState.onClick === "function" &&
!buttonState.disabled;
const handleCardClick = () => {
if (isCardClickable) {
buttonState.onClick?.();
}
};
return (
<div
key={`${key}-${providerType}`}
onClick={isCardClickable ? handleCardClick : undefined}
className={cn(
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
isHighlighted
? "border-action-link-05"
: "border-border-01",
isCardClickable &&
"cursor-pointer hover:bg-background-tint-01 transition-colors"
)}
>
<div className="flex flex-1 items-start gap-1 px-2 py-1">
{renderLogo({
logoSrc,
alt: `${label} logo`,
size: 16,
isHighlighted,
})}
<Content
title={label}
description={subtitle}
sizePreset="main-ui"
variant="section"
/>
</div>
<div className="flex items-center justify-end gap-2">
{isConfigured && (
<OpalButton
icon={SvgEdit}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={() => {
if (!canOpenModal) return;
: undefined
}
onEdit={
isConfigured && canOpenModal
? () => {
openSearchModal(
providerType as WebSearchProviderType,
provider
);
}}
aria-label={`Edit ${label}`}
/>
)}
{buttonState.icon === "check" ? (
<HoverIconButton
isHovered={isButtonHovered}
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
onMouseLeave={() => setHoveredButtonKey(null)}
action={true}
tertiary
disabled={buttonState.disabled}
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
>
{buttonState.label}
</HoverIconButton>
) : (
<Disabled
disabled={
buttonState.disabled || !buttonState.onClick
}
>
<OpalButton
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
rightIcon={
buttonState.icon === "arrow"
? SvgArrowExchange
: buttonState.icon === "arrow-circle"
? SvgArrowRightCircle
: undefined
}
>
{buttonState.label}
</OpalButton>
</Disabled>
)}
</div>
</div>
: undefined
}
onDisconnect={
isConfigured && provider && provider.id > 0
? () =>
setDisconnectTarget({
id: provider.id,
label,
category: "search",
providerType,
})
: undefined
}
/>
);
}
)}
@@ -1076,161 +1191,81 @@ export default function Page() {
const isCurrentCrawler =
provider.provider_type === currentContentProviderType;
const buttonState = (() => {
if (!isConfigured) {
return {
label: "Connect",
icon: "arrow" as const,
disabled: false,
onClick: () => {
openContentModal(provider.provider_type, provider);
setContentActivationError(null);
},
};
}
const status: "disconnected" | "connected" | "selected" =
!isConfigured
? "disconnected"
: isCurrentCrawler
? "selected"
: "connected";
if (isCurrentCrawler) {
return {
label: "Current Crawler",
icon: "check" as const,
disabled: false,
onClick: () => {
void handleDeactivateContentProvider(
providerId,
provider.provider_type
);
},
};
}
const canActivate =
providerId > 0 ||
provider.provider_type === "onyx_web_crawler" ||
isConfigured;
const canActivate =
providerId > 0 ||
provider.provider_type === "onyx_web_crawler" ||
isConfigured;
return {
label: "Set as Default",
icon: "arrow-circle" as const,
disabled: !canActivate,
onClick: canActivate
? () => {
void handleActivateContentProvider(provider);
}
: undefined,
};
})();
const contentButtonKey = `content-${provider.provider_type}-${provider.id}`;
const isContentButtonHovered =
hoveredButtonKey === contentButtonKey;
const isContentCardClickable =
buttonState.icon === "arrow" &&
typeof buttonState.onClick === "function" &&
!buttonState.disabled;
const handleContentCardClick = () => {
if (isContentCardClickable) {
buttonState.onClick?.();
}
};
const contentLogoSrc =
CONTENT_PROVIDER_DETAILS[provider.provider_type]?.logoSrc;
return (
<div
<Select
key={`${provider.provider_type}-${provider.id}`}
onClick={
isContentCardClickable
? handleContentCardClick
icon={() =>
contentLogoSrc ? (
<Image
src={contentLogoSrc}
alt={`${label} logo`}
width={16}
height={16}
/>
) : provider.provider_type === "onyx_web_crawler" ? (
<SvgOnyxLogo size={16} />
) : (
<SvgGlobe size={16} />
)
}
title={label}
description={subtitle}
status={status}
selectedLabel="Current Crawler"
onConnect={() => {
openContentModal(provider.provider_type, provider);
setContentActivationError(null);
}}
onSelect={
canActivate
? () => {
void handleActivateContentProvider(provider);
}
: undefined
}
className={cn(
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
isCurrentCrawler
? "border-action-link-05"
: "border-border-01",
isContentCardClickable &&
"cursor-pointer hover:bg-background-tint-01 transition-colors"
)}
>
<div className="flex flex-1 items-start gap-1 px-2 py-1">
{renderLogo({
logoSrc:
CONTENT_PROVIDER_DETAILS[provider.provider_type]
?.logoSrc,
alt: `${label} logo`,
fallback:
provider.provider_type === "onyx_web_crawler" ? (
<SvgOnyxLogo size={16} />
) : undefined,
size: 16,
isHighlighted: isCurrentCrawler,
})}
<Content
title={label}
description={subtitle}
sizePreset="main-ui"
variant="section"
/>
</div>
<div className="flex items-center justify-end gap-2">
{provider.provider_type !== "onyx_web_crawler" &&
isConfigured && (
<OpalButton
icon={SvgEdit}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={() => {
openContentModal(
provider.provider_type,
provider
);
}}
aria-label={`Edit ${label}`}
/>
)}
{buttonState.icon === "check" ? (
<HoverIconButton
isHovered={isContentButtonHovered}
onMouseEnter={() =>
setHoveredButtonKey(contentButtonKey)
onDeselect={() => {
void handleDeactivateContentProvider(
providerId,
provider.provider_type
);
}}
onEdit={
provider.provider_type !== "onyx_web_crawler" &&
isConfigured
? () => {
openContentModal(provider.provider_type, provider);
}
onMouseLeave={() => setHoveredButtonKey(null)}
action={true}
tertiary
disabled={buttonState.disabled}
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
>
{buttonState.label}
</HoverIconButton>
) : (
<Disabled
disabled={
buttonState.disabled || !buttonState.onClick
}
>
<OpalButton
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
rightIcon={
buttonState.icon === "arrow"
? SvgArrowExchange
: buttonState.icon === "arrow-circle"
? SvgArrowRightCircle
: undefined
}
>
{buttonState.label}
</OpalButton>
</Disabled>
)}
</div>
</div>
: undefined
}
onDisconnect={
provider.provider_type !== "onyx_web_crawler" &&
isConfigured &&
provider.id > 0
? () =>
setDisconnectTarget({
id: provider.id,
label,
category: "content",
providerType: provider.provider_type,
})
: undefined
}
/>
);
})}
</div>
@@ -1238,6 +1273,21 @@ export default function Page() {
</SettingsLayouts.Body>
</SettingsLayouts.Root>
{disconnectTarget && (
<WebSearchDisconnectModal
disconnectTarget={disconnectTarget}
searchProviders={searchProviders}
contentProviders={combinedContentProviders}
replacementProviderId={replacementProviderId}
onReplacementChange={setReplacementProviderId}
onClose={() => {
setDisconnectTarget(null);
setReplacementProviderId(null);
}}
onDisconnect={() => void handleDisconnectProvider()}
/>
)}
<WebProviderSetupModal
isOpen={selectedProviderType !== null}
onClose={() => {

View File

@@ -19,6 +19,10 @@
background-color: var(--background-neutral-00);
border: 1px solid var(--status-error-05);
}
/* Keep a visible focus ring on errored inputs; the ring is intentionally
   suppressed while the input is being actively pressed (:active). */
.input-error:focus:not(:active),
.input-error:focus-within:not(:active) {
  box-shadow: inset 0px 0px 0px 2px var(--background-tint-04);
}
.input-disabled {
background-color: var(--background-neutral-03);

View File

@@ -5,6 +5,7 @@
* and support various display sizes.
*/
import React from "react";
import { SvgBifrost } from "@opal/icons";
import { render } from "@tests/setup/test-utils";
import { GithubIcon, GitbookIcon, ConfluenceIcon } from "./icons";
@@ -51,4 +52,15 @@ describe("Logo Icons", () => {
render(<GithubIcon size={100} className="custom-class" />);
}).not.toThrow();
});
test("renders the Bifrost icon with theme-aware colors", () => {
const { container } = render(
<SvgBifrost size={32} className="custom text-red-500 dark:text-black" />
);
const icon = container.querySelector("svg");
expect(icon).toBeInTheDocument();
expect(icon).toHaveClass("custom", "text-[#33C19E]", "dark:text-white");
expect(icon).not.toHaveClass("text-red-500", "dark:text-black");
});
});

View File

@@ -13,6 +13,7 @@ export enum LLMProviderName {
VERTEX_AI = "vertex_ai",
BEDROCK = "bedrock",
LITELLM_PROXY = "litellm_proxy",
BIFROST = "bifrost",
CUSTOM = "custom",
}
@@ -165,6 +166,21 @@ export interface LiteLLMProxyModelResponse {
model_name: string;
}
// Request parameters for listing models from a Bifrost gateway.
export interface BifrostFetchParams {
  // Base URL of the Bifrost gateway, when not using the server default.
  api_base?: string;
  // Credential for authenticating against the gateway.
  api_key?: string;
  // NOTE(review): presumably scopes the listing to one upstream provider —
  // confirm against the backend fetch endpoint.
  provider_name?: string;
  // Lets callers abort an in-flight model fetch.
  signal?: AbortSignal;
}
// Shape of one model entry returned by the Bifrost models endpoint.
export interface BifrostModelResponse {
  // Model identifier used in API calls.
  name: string;
  // Human-friendly name for display in the UI.
  display_name: string;
  // Context window in tokens; null when the gateway does not report one.
  max_input_tokens: number | null;
  // Whether the model accepts image inputs.
  supports_image_input: boolean;
  // Whether the model exposes reasoning capabilities.
  supports_reasoning: boolean;
}
export interface VertexAIFetchParams {
model_configurations?: ModelConfiguration[];
}
@@ -182,5 +198,6 @@ export type FetchModelsParams =
| OllamaFetchParams
| OpenRouterFetchParams
| LiteLLMProxyFetchParams
| BifrostFetchParams
| VertexAIFetchParams
| LMStudioFetchParams;

View File

@@ -37,6 +37,7 @@ export interface Settings {
// User Knowledge settings
user_knowledge_enabled?: boolean;
user_file_max_upload_size_mb?: number | null;
file_token_count_threshold_k?: number | null;
// Connector settings
show_extra_connectors?: boolean;
@@ -68,6 +69,12 @@ export interface Settings {
// Application version from the ONYX_VERSION env var on the server.
version?: string | null;
// Hard ceiling for user_file_max_upload_size_mb, derived from env var.
max_allowed_upload_size_mb?: number;
// Factory defaults for the restore button.
default_user_file_max_upload_size_mb?: number;
default_file_token_count_threshold_k?: number;
}
export enum NotificationType {

View File

@@ -53,6 +53,12 @@ export async function fetchVoicesByType(
return fetch(`/api/admin/voice/voices?provider_type=${providerType}`);
}
/**
 * Issues a DELETE for the given voice provider and returns the raw Response
 * so callers can inspect status/body themselves.
 */
export async function deleteVoiceProvider(
  providerId: number
): Promise<Response> {
  const endpoint = `${VOICE_PROVIDERS_URL}/${providerId}`;
  return fetch(endpoint, { method: "DELETE" });
}
export async function fetchLLMProviders(): Promise<Response> {
return fetch("/api/admin/llm/provider");
}

View File

@@ -1,5 +1,6 @@
import type { IconFunctionComponent } from "@opal/types";
import {
SvgBifrost,
SvgCpu,
SvgOpenai,
SvgClaude,
@@ -26,6 +27,7 @@ const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
[LLMProviderName.OLLAMA_CHAT]: SvgOllama,
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
[LLMProviderName.BIFROST]: SvgBifrost,
// fallback
[LLMProviderName.CUSTOM]: SvgServer,
@@ -42,6 +44,7 @@ const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
[LLMProviderName.OLLAMA_CHAT]: "Ollama",
[LLMProviderName.OPENROUTER]: "OpenRouter",
[LLMProviderName.LM_STUDIO]: "LM Studio",
[LLMProviderName.BIFROST]: "Bifrost",
// fallback
[LLMProviderName.CUSTOM]: "Custom Models",
@@ -58,6 +61,7 @@ const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
[LLMProviderName.OLLAMA_CHAT]: "Ollama",
[LLMProviderName.OPENROUTER]: "OpenRouter",
[LLMProviderName.LM_STUDIO]: "LM Studio",
[LLMProviderName.BIFROST]: "Bifrost",
// fallback
[LLMProviderName.CUSTOM]: "Other providers or self-hosted",

View File

@@ -85,8 +85,6 @@ function buildFileKey(file: File): string {
return `${file.size}|${namePrefix}`;
}
const DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = 50;
interface ProjectsContextType {
projects: Project[];
recentFiles: ProjectFile[];
@@ -341,21 +339,20 @@ export function ProjectsProvider({ children }: ProjectsProviderProps) {
onFailure?: (failedTempIds: string[]) => void
): Promise<ProjectFile[]> => {
const rawMax = settingsContext?.settings?.user_file_max_upload_size_mb;
const maxUploadSizeMb =
rawMax && rawMax > 0 ? rawMax : DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB;
const maxUploadSizeBytes = maxUploadSizeMb * 1024 * 1024;
const oversizedFiles = files.filter(
(file) => file.size > maxUploadSizeBytes
);
const validFiles = files.filter(
(file) => file.size <= maxUploadSizeBytes
);
const oversizedFiles =
rawMax && rawMax > 0
? files.filter((file) => file.size > rawMax * 1024 * 1024)
: [];
const validFiles =
rawMax && rawMax > 0
? files.filter((file) => file.size <= rawMax * 1024 * 1024)
: files;
if (oversizedFiles.length > 0) {
const skippedNames = oversizedFiles.map((file) => file.name).join(", ");
toast.warning(
`Skipped ${oversizedFiles.length} oversized file(s) (>${maxUploadSizeMb} MB): ${skippedNames}`
`Skipped ${oversizedFiles.length} oversized file(s) (>${rawMax} MB): ${skippedNames}`
);
}

View File

@@ -457,6 +457,7 @@ const ModalHeader = React.forwardRef<HTMLDivElement, ModalHeaderProps>(
<div
tabIndex={-1}
ref={closeButtonRef as React.RefObject<HTMLDivElement>}
className="outline-none"
>
<DialogPrimitive.Close asChild>
<Button

View File

@@ -12,6 +12,7 @@ import {
SvgArrowRightCircle,
SvgCheckSquare,
SvgSettings,
SvgUnplug,
} from "@opal/icons";
const containerClasses = {
@@ -35,6 +36,7 @@ export interface SelectProps
onSelect?: () => void;
onDeselect?: () => void;
onEdit?: () => void;
onDisconnect?: () => void;
// Labels (customizable)
connectLabel?: string;
@@ -59,6 +61,7 @@ export default function Select({
onSelect,
onDeselect,
onEdit,
onDisconnect,
connectLabel = "Connect",
selectLabel = "Set as Default",
selectedLabel = "Current Default",
@@ -68,7 +71,7 @@ export default function Select({
disabled,
...rest
}: SelectProps) {
const sizeClass = medium ? "h-[3.75rem]" : "h-[4.25rem]";
const sizeClass = medium ? "h-[3.75rem]" : "min-h-[3.75rem] max-h-[5.25rem]";
const containerClass = containerClasses[status];
const [isHovered, setIsHovered] = useState(false);
@@ -121,7 +124,7 @@ export default function Select({
</div>
{/* Right section - Actions */}
<div className="flex items-center justify-end gap-1">
<div className="flex flex-col h-full items-end justify-between gap-1">
{/* Disconnected: Show Connect button */}
{isDisconnected && (
<Disabled disabled={disabled || !onConnect}>
@@ -149,18 +152,32 @@ export default function Select({
{selectLabel}
</SelectButton>
</Disabled>
{onEdit && (
<Disabled disabled={disabled}>
<Button
icon={SvgSettings}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={noProp(onEdit)}
aria-label={`Edit ${title}`}
/>
</Disabled>
)}
<div className="flex px-1 gap-1">
{onDisconnect && (
<Disabled disabled={disabled}>
<Button
icon={SvgUnplug}
tooltip="Disconnect"
prominence="tertiary"
size="sm"
onClick={noProp(onDisconnect)}
aria-label={`Disconnect ${title}`}
/>
</Disabled>
)}
{onEdit && (
<Disabled disabled={disabled}>
<Button
icon={SvgSettings}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={noProp(onEdit)}
aria-label={`Edit ${title}`}
/>
</Disabled>
)}
</div>
</>
)}
@@ -177,18 +194,32 @@ export default function Select({
{selectedLabel}
</SelectButton>
</Disabled>
{onEdit && (
<Disabled disabled={disabled}>
<Button
icon={SvgSettings}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={noProp(onEdit)}
aria-label={`Edit ${title}`}
/>
</Disabled>
)}
<div className="flex px-1 gap-1">
{onDisconnect && (
<Disabled disabled={disabled}>
<Button
icon={SvgUnplug}
tooltip="Disconnect"
prominence="tertiary"
size="sm"
onClick={noProp(onDisconnect)}
aria-label={`Disconnect ${title}`}
/>
</Disabled>
)}
{onEdit && (
<Disabled disabled={disabled}>
<Button
icon={SvgSettings}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={noProp(onEdit)}
aria-label={`Edit ${title}`}
/>
</Disabled>
)}
</div>
</>
)}
</div>

View File

@@ -12,7 +12,7 @@ export interface FieldContextType {
export type FormFieldRootProps = React.HTMLAttributes<HTMLDivElement> & {
name?: string;
state: FormFieldState;
state?: FormFieldState;
required?: boolean;
id?: string;
};

View File

@@ -25,6 +25,7 @@ import {
SvgFold,
SvgExternalLink,
SvgAlertCircle,
SvgRefreshCw,
} from "@opal/icons";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import { Content } from "@opal/layouts";
@@ -54,6 +55,7 @@ import * as ExpandableCard from "@/layouts/expandable-card-layouts";
import * as ActionsLayouts from "@/layouts/actions-layouts";
import { getActionIcon } from "@/lib/tools/mcpUtils";
import { Disabled } from "@opal/core";
import IconButton from "@/refresh-components/buttons/IconButton";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import useFilter from "@/hooks/useFilter";
import { MCPServer } from "@/lib/tools/interfaces";
@@ -81,6 +83,10 @@ interface ChatPreferencesFormValues {
maximum_chat_retention_days: string;
anonymous_user_enabled: boolean;
disable_default_assistant: boolean;
// File limits
user_file_max_upload_size_mb: string;
file_token_count_threshold_k: string;
}
interface MCPServerCardTool {
@@ -185,6 +191,173 @@ function MCPServerCard({
);
}
type FileLimitFieldName =
| "user_file_max_upload_size_mb"
| "file_token_count_threshold_k";
interface NumericLimitFieldProps {
name: FileLimitFieldName;
defaultValue: string;
saveSettings: (updates: Partial<Settings>) => Promise<void>;
maxValue?: number;
allowZero?: boolean;
}
/**
 * Self-saving numeric settings input bound to the surrounding Formik form.
 *
 * Persists on blur via `saveSettings`. A "restore default" button appears
 * when the current value differs from `defaultValue`. When `maxValue` is
 * set, values above it render the error variant and are never saved. When
 * `allowZero` is true, empty/zero input means "no limit" and persists as 0.
 */
function NumericLimitField({
  name,
  defaultValue,
  saveSettings,
  maxValue,
  allowZero = false,
}: NumericLimitFieldProps) {
  const { values, setFieldValue } =
    useFormikContext<ChatPreferencesFormValues>();
  // Last persisted display value — used to skip redundant saves on blur.
  const initialValue = useRef(values[name]);
  // Guards against the extra blur fired by clicking the restore button.
  const restoringRef = useRef(false);
  const value = values[name];
  const parsed = parseInt(value, 10);
  const isOverMax =
    maxValue !== undefined && !isNaN(parsed) && parsed > maxValue;
  const handleRestore = () => {
    restoringRef.current = true;
    initialValue.current = defaultValue;
    void setFieldValue(name, defaultValue);
    void saveSettings({ [name]: parseInt(defaultValue, 10) });
  };
  const handleBlur = () => {
    // The restore button triggers a blur — skip since handleRestore already saved.
    if (restoringRef.current) {
      restoringRef.current = false;
      return;
    }
    const parsed = parseInt(value, 10);
    const isValid = !isNaN(parsed) && (allowZero ? parsed >= 0 : parsed > 0);
    // Revert invalid input (empty, NaN, negative).
    if (!isValid) {
      if (allowZero) {
        // Empty/invalid means "no limit" — persist 0 and clear the field.
        void setFieldValue(name, "");
        void saveSettings({ [name]: 0 });
        initialValue.current = "";
      } else {
        void setFieldValue(name, initialValue.current);
      }
      return;
    }
    // Block save when the value exceeds the hard ceiling.
    if (maxValue !== undefined && parsed > maxValue) {
      return;
    }
    // For allowZero fields, 0 means "no limit" — clear the display
    // so the "No limit" placeholder is visible, but still persist 0.
    if (allowZero && parsed === 0) {
      void setFieldValue(name, "");
      if (initialValue.current !== "") {
        void saveSettings({ [name]: 0 });
        initialValue.current = "";
      }
      return;
    }
    const normalizedDisplay = String(parsed);
    // Update the display to the canonical form (e.g. strip leading zeros).
    if (value !== normalizedDisplay) {
      void setFieldValue(name, normalizedDisplay);
    }
    // Persist only when the value actually changed.
    if (normalizedDisplay !== initialValue.current) {
      void saveSettings({ [name]: parsed });
      initialValue.current = normalizedDisplay;
    }
  };
  return (
    <div className="group w-full">
      <InputTypeInField
        name={name}
        inputMode="numeric"
        showClearButton={false}
        pattern="[0-9]*"
        placeholder={allowZero ? "No limit" : `Default: ${defaultValue}`}
        variant={isOverMax ? "error" : undefined}
        rightSection={
          // Restore affordance appears only when the value deviates from default.
          (value || "") !== defaultValue ? (
            <div className="opacity-0 group-hover:opacity-100 group-focus-within:opacity-100 transition-opacity">
              <IconButton
                icon={SvgRefreshCw}
                tooltip="Restore default"
                internal
                type="button"
                onClick={handleRestore}
              />
            </div>
          ) : undefined
        }
        onBlur={handleBlur}
      />
    </div>
  );
}
interface FileSizeLimitFieldsProps {
saveSettings: (updates: Partial<Settings>) => Promise<void>;
defaultUploadSizeMb: string;
defaultTokenThresholdK: string;
maxAllowedUploadSizeMb?: number;
}
/**
 * Side-by-side numeric inputs for the two file-attachment limits:
 * maximum upload size (MB) and token-count threshold (thousands of tokens).
 * Both fields save themselves on blur through `saveSettings`.
 */
function FileSizeLimitFields({
  saveSettings,
  defaultUploadSizeMb,
  defaultTokenThresholdK,
  maxAllowedUploadSizeMb,
}: FileSizeLimitFieldsProps) {
  return (
    <div className="flex gap-4 w-full items-start">
      <div className="flex-1">
        <InputLayouts.Vertical
          title="File Size Limit (MB)"
          subDescription={
            // Surface the hard ceiling when one is configured server-side.
            maxAllowedUploadSizeMb
              ? `Max: ${maxAllowedUploadSizeMb} MB`
              : undefined
          }
          nonInteractive
        >
          <NumericLimitField
            name="user_file_max_upload_size_mb"
            defaultValue={defaultUploadSizeMb}
            saveSettings={saveSettings}
            maxValue={maxAllowedUploadSizeMb}
          />
        </InputLayouts.Vertical>
      </div>
      <div className="flex-1">
        <InputLayouts.Vertical
          title="File Token Limit (thousand tokens)"
          nonInteractive
        >
          <NumericLimitField
            name="file_token_count_threshold_k"
            defaultValue={defaultTokenThresholdK}
            saveSettings={saveSettings}
            allowZero
          />
        </InputLayouts.Vertical>
      </div>
    </div>
  );
}
/**
* Inner form component that uses useFormikContext to access values
* and create save handlers for settings fields.
@@ -201,6 +374,7 @@ function ChatPreferencesForm() {
// Tools availability
const { tools: availableTools } = useAvailableTools();
const vectorDbEnabled = useVectorDbEnabled();
const searchTool = availableTools.find(
(t) => t.in_code_tool_id === SEARCH_TOOL_ID
);
@@ -723,6 +897,28 @@ function ChatPreferencesForm() {
</InputLayouts.Horizontal>
</Card>
<Card>
<InputLayouts.Vertical
title="File Attachment Size Limit"
description="Files attached in chats and projects must fit within both limits to be accepted. Larger files increase latency, memory usage, and token costs."
>
<FileSizeLimitFields
saveSettings={saveSettings}
defaultUploadSizeMb={
settings?.settings.default_user_file_max_upload_size_mb?.toString() ??
"100"
}
defaultTokenThresholdK={
settings?.settings.default_file_token_count_threshold_k?.toString() ??
"200"
}
maxAllowedUploadSizeMb={
settings?.settings.max_allowed_upload_size_mb
}
/>
</InputLayouts.Vertical>
</Card>
<Card>
<InputLayouts.Horizontal
title="Allow Anonymous Users"
@@ -862,6 +1058,21 @@ export default function ChatPreferencesPage() {
anonymous_user_enabled: settings.settings.anonymous_user_enabled ?? false,
disable_default_assistant:
settings.settings.disable_default_assistant ?? false,
// File limits — for upload size: 0/null means "use default";
// for token threshold: null means "use default", 0 means "no limit".
user_file_max_upload_size_mb:
(settings.settings.user_file_max_upload_size_mb ?? 0) <= 0
? settings.settings.default_user_file_max_upload_size_mb?.toString() ??
"100"
: settings.settings.user_file_max_upload_size_mb!.toString(),
file_token_count_threshold_k:
settings.settings.file_token_count_threshold_k == null
? settings.settings.default_file_token_count_threshold_k?.toString() ??
"200"
: settings.settings.file_token_count_threshold_k === 0
? ""
: settings.settings.file_token_count_threshold_k.toString(),
};
return (

View File

@@ -0,0 +1,559 @@
"use client";
import { useState } from "react";
import { Button, Text } from "@opal/components";
import { Disabled } from "@opal/core";
import {
SvgCheckCircle,
SvgHookNodes,
SvgLoader,
SvgRevert,
} from "@opal/icons";
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
import { FormField } from "@/refresh-components/form/FormField";
import { Section } from "@/layouts/general-layouts";
import { ContentAction } from "@opal/layouts";
import { toast } from "@/hooks/useToast";
import {
createHook,
updateHook,
HookAuthError,
HookTimeoutError,
HookConnectError,
} from "@/refresh-pages/admin/HooksPage/svc";
import type {
HookFailStrategy,
HookFormState,
HookPointMeta,
HookResponse,
HookUpdateRequest,
} from "@/refresh-pages/admin/HooksPage/interfaces";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
interface HookFormModalProps {
open: boolean;
onOpenChange: (open: boolean) => void;
/** When provided, the modal is in edit mode for this hook. */
hook?: HookResponse;
/** When provided (create mode), the hook point is pre-selected and locked. */
spec?: HookPointMeta;
onSuccess: (hook: HookResponse) => void;
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/**
 * Derives the hook modal's initial form values.
 *
 * Edit mode (a hook is provided): prefill from the existing hook; the API
 * key is deliberately left blank so an unchanged key is never re-submitted.
 * Create mode: fall back to the hook-point spec's defaults when available.
 */
function buildInitialState(
  hook: HookResponse | undefined,
  spec: HookPointMeta | undefined
): HookFormState {
  if (!hook) {
    return {
      name: "",
      endpoint_url: "",
      api_key: "",
      fail_strategy: spec?.default_fail_strategy ?? "hard",
      timeout_seconds: spec ? String(spec.default_timeout_seconds) : "30",
    };
  }
  const { name, endpoint_url, fail_strategy, timeout_seconds } = hook;
  return {
    name,
    endpoint_url: endpoint_url ?? "",
    api_key: "",
    fail_strategy,
    timeout_seconds: String(timeout_seconds),
  };
}
const SOFT_DESCRIPTION =
"If the endpoint returns an error, Onyx logs it and continues the pipeline as normal, ignoring the hook result.";
const MAX_TIMEOUT_SECONDS = 600;
// ---------------------------------------------------------------------------
// Component
// ---------------------------------------------------------------------------
/**
 * Create/edit modal for a hook extension.
 *
 * Create mode (`spec` provided): collects name, endpoint URL, required API
 * key, fail strategy and timeout, then POSTs via `createHook` and briefly
 * shows a "Connection valid." state before closing.
 * Edit mode (`hook` provided): pre-fills from the hook and PATCHes only the
 * changed fields via `updateHook`; an empty API key field leaves the stored
 * key untouched unless the user explicitly cleared it.
 * Server-side failures from the hooks service (auth / timeout / connect) are
 * mapped onto the corresponding field's error state; anything else is toasted.
 */
export default function HookFormModal({
  open,
  onOpenChange,
  hook,
  spec,
  onSuccess,
}: HookFormModalProps) {
  const isEdit = !!hook;
  const [form, setForm] = useState<HookFormState>(() =>
    buildInitialState(hook, spec)
  );
  const [isSubmitting, setIsSubmitting] = useState(false);
  // Create mode only: drives the "Connection valid." confirmation row after a
  // successful create.
  const [isConnected, setIsConnected] = useState(false);
  // Tracks whether the user explicitly cleared the API key field in edit mode.
  // - false + empty field → key unchanged (omitted from PATCH)
  // - true + empty field → key cleared (api_key: null sent to backend)
  // - false + non-empty → new key provided (new value sent to backend)
  const [apiKeyCleared, setApiKeyCleared] = useState(false);
  // Per-field "blurred at least once" flags; empty-field errors only render
  // after the user has interacted with the field.
  const [touched, setTouched] = useState({
    name: false,
    endpoint_url: false,
    api_key: false,
  });
  // Server-side validation state, set from typed errors in handleSubmit.
  const [apiKeyServerError, setApiKeyServerError] = useState(false);
  const [endpointServerError, setEndpointServerError] = useState<string | null>(
    null
  );
  const [timeoutServerError, setTimeoutServerError] = useState(false);
  // Marks a field as touched so its validation message may show.
  function touch(key: keyof typeof touched) {
    setTouched((prev) => ({ ...prev, [key]: true }));
  }
  // Wraps onOpenChange: refuses to close mid-submit, and resets all local
  // state 200ms after closing — presumably to let the close animation finish
  // before fields visibly reset (TODO confirm against Modal's transition).
  function handleOpenChange(next: boolean) {
    if (!next) {
      if (isSubmitting) return;
      setTimeout(() => {
        setForm(buildInitialState(hook, spec));
        setIsConnected(false);
        setApiKeyCleared(false);
        setTouched({ name: false, endpoint_url: false, api_key: false });
        setApiKeyServerError(false);
        setEndpointServerError(null);
        setTimeoutServerError(false);
      }, 200);
    }
    onOpenChange(next);
  }
  // Type-safe single-field updater for the form state.
  function set<K extends keyof HookFormState>(key: K, value: HookFormState[K]) {
    setForm((prev) => ({ ...prev, [key]: value }));
  }
  // --- Client-side validation ---------------------------------------------
  const timeoutNum = parseFloat(form.timeout_seconds);
  const isTimeoutValid =
    !isNaN(timeoutNum) && timeoutNum > 0 && timeoutNum <= MAX_TIMEOUT_SECONDS;
  // API key is only required when creating; edits may keep the stored key.
  const isValid =
    form.name.trim().length > 0 &&
    form.endpoint_url.trim().length > 0 &&
    isTimeoutValid &&
    (isEdit || form.api_key.trim().length > 0);
  const nameError = touched.name && !form.name.trim();
  const endpointEmptyError = touched.endpoint_url && !form.endpoint_url.trim();
  // Empty-field error takes precedence over a server connect error.
  const endpointFieldError = endpointEmptyError
    ? "Endpoint URL cannot be empty."
    : endpointServerError ?? undefined;
  const apiKeyEmptyError = !isEdit && touched.api_key && !form.api_key.trim();
  const apiKeyFieldError = apiKeyEmptyError
    ? "API key cannot be empty."
    : apiKeyServerError
    ? "Invalid API key."
    : undefined;
  // On blur, snap an invalid timeout back to the hook's current value (edit)
  // or the spec default (create), clearing any stale server timeout error.
  function handleTimeoutBlur() {
    if (!isTimeoutValid) {
      const fallback = hook?.timeout_seconds ?? spec?.default_timeout_seconds;
      if (fallback !== undefined) {
        set("timeout_seconds", String(fallback));
        if (timeoutServerError) setTimeoutServerError(false);
      }
    }
  }
  // Edit mode: submit stays disabled until something differs from the stored
  // hook. Create mode always counts as changed.
  const hasChanges =
    isEdit && hook
      ? form.name !== hook.name ||
        form.endpoint_url !== (hook.endpoint_url ?? "") ||
        form.fail_strategy !== hook.fail_strategy ||
        timeoutNum !== hook.timeout_seconds ||
        form.api_key.trim().length > 0 ||
        apiKeyCleared
      : true;
  // Submits a create or a sparse (changed-fields-only) update, then maps any
  // typed service error onto the matching field's error state.
  async function handleSubmit() {
    if (!isValid) return;
    setIsSubmitting(true);
    try {
      let result: HookResponse;
      if (isEdit && hook) {
        // Build a PATCH-style request containing only changed fields.
        const req: HookUpdateRequest = {};
        if (form.name !== hook.name) req.name = form.name;
        if (form.endpoint_url !== (hook.endpoint_url ?? ""))
          req.endpoint_url = form.endpoint_url;
        if (form.fail_strategy !== hook.fail_strategy)
          req.fail_strategy = form.fail_strategy;
        if (timeoutNum !== hook.timeout_seconds)
          req.timeout_seconds = timeoutNum;
        if (form.api_key.trim().length > 0) {
          req.api_key = form.api_key;
        } else if (apiKeyCleared) {
          // Explicit null tells the backend to delete the stored key.
          req.api_key = null;
        }
        if (Object.keys(req).length === 0) {
          // Nothing changed — close without a network call.
          setIsSubmitting(false);
          handleOpenChange(false);
          return;
        }
        result = await updateHook(hook.id, req);
      } else {
        if (!spec) {
          toast.error("No hook point specified.");
          setIsSubmitting(false);
          return;
        }
        result = await createHook({
          name: form.name,
          hook_point: spec.hook_point,
          endpoint_url: form.endpoint_url,
          // Omit api_key entirely when blank rather than sending "".
          ...(form.api_key ? { api_key: form.api_key } : {}),
          fail_strategy: form.fail_strategy,
          timeout_seconds: timeoutNum,
        });
      }
      toast.success(isEdit ? "Hook updated." : "Hook created.");
      onSuccess(result);
      if (!isEdit) {
        // Briefly show the "Connection valid." row before closing.
        setIsConnected(true);
        await new Promise((resolve) => setTimeout(resolve, 500));
      }
      setIsSubmitting(false);
      handleOpenChange(false);
    } catch (err) {
      // Typed errors from the hooks service map to specific fields; any other
      // failure surfaces as a generic toast.
      if (err instanceof HookAuthError) {
        setApiKeyServerError(true);
      } else if (err instanceof HookTimeoutError) {
        setTimeoutServerError(true);
      } else if (err instanceof HookConnectError) {
        setEndpointServerError(err.message || "Could not connect to endpoint.");
      } else {
        toast.error(
          err instanceof Error ? err.message : "Something went wrong."
        );
      }
      setIsSubmitting(false);
    }
  }
  // Header copy: derived from the spec when creating, or the hook when editing.
  const hookPointDisplayName =
    spec?.display_name ?? spec?.hook_point ?? hook?.hook_point ?? "";
  const hookPointDescription = spec?.description;
  const docsUrl = spec?.docs_url;
  const failStrategyDescription =
    form.fail_strategy === "soft"
      ? SOFT_DESCRIPTION
      : spec?.fail_hard_description;
  return (
    <Modal open={open} onOpenChange={handleOpenChange}>
      <Modal.Content width="md" height="fit">
        <Modal.Header
          icon={SvgHookNodes}
          title={isEdit ? "Manage Hook Extension" : "Set Up Hook Extension"}
          description={
            isEdit
              ? undefined
              : "Connect an external API endpoint to extend the hook point."
          }
          onClose={() => handleOpenChange(false)}
        />
        <Modal.Body>
          {/* Hook point section header */}
          <ContentAction
            sizePreset="main-ui"
            variant="section"
            paddingVariant="fit"
            title={hookPointDisplayName}
            description={hookPointDescription}
            rightChildren={
              <Section
                flexDirection="column"
                alignItems="end"
                width="fit"
                height="fit"
                gap={0.25}
              >
                <div className="flex items-center gap-1">
                  <SvgHookNodes
                    style={{ width: "1rem", height: "1rem" }}
                    className="text-text-03 shrink-0"
                  />
                  <Text font="secondary-body" color="text-03">
                    Hook Point
                  </Text>
                </div>
                {docsUrl && (
                  <a
                    href={docsUrl}
                    target="_blank"
                    rel="noopener noreferrer"
                    className="underline"
                  >
                    <Text font="secondary-body" color="text-03">
                      Documentation
                    </Text>
                  </a>
                )}
              </Section>
            }
          />
          {/* Display name */}
          <FormField className="w-full" state={nameError ? "error" : "idle"}>
            <FormField.Label>Display Name</FormField.Label>
            <FormField.Control>
              <div className="[&_input::placeholder]:!font-main-ui-muted w-full">
                <InputTypeIn
                  value={form.name}
                  onChange={(e) => set("name", e.target.value)}
                  onBlur={() => touch("name")}
                  placeholder="Name your extension at this hook point"
                  variant={
                    isSubmitting ? "disabled" : nameError ? "error" : undefined
                  }
                />
              </div>
            </FormField.Control>
            <FormField.Message
              messages={{ error: "Display name cannot be empty." }}
            />
          </FormField>
          {/* Fail strategy */}
          <FormField className="w-full">
            <FormField.Label>Fail Strategy</FormField.Label>
            <FormField.Control>
              <InputSelect
                value={form.fail_strategy}
                onValueChange={(v) =>
                  set("fail_strategy", v as HookFailStrategy)
                }
                disabled={isSubmitting}
              >
                <InputSelect.Trigger placeholder="Select strategy" />
                <InputSelect.Content>
                  <InputSelect.Item value="soft">
                    Log Error and Continue
                    {spec?.default_fail_strategy === "soft" && (
                      <>
                        {" "}
                        <Text color="text-03">(Default)</Text>
                      </>
                    )}
                  </InputSelect.Item>
                  <InputSelect.Item value="hard">
                    Block Pipeline on Failure
                    {spec?.default_fail_strategy === "hard" && (
                      <>
                        {" "}
                        <Text color="text-03">(Default)</Text>
                      </>
                    )}
                  </InputSelect.Item>
                </InputSelect.Content>
              </InputSelect>
            </FormField.Control>
            <FormField.Description>
              {failStrategyDescription}
            </FormField.Description>
          </FormField>
          {/* Timeout (seconds) */}
          <FormField
            className="w-full"
            state={timeoutServerError ? "error" : "idle"}
          >
            <FormField.Label>
              Timeout{" "}
              <Text font="main-ui-action" color="text-03">
                (seconds)
              </Text>
            </FormField.Label>
            <FormField.Control>
              <div className="[&_input]:!font-main-ui-mono [&_input::placeholder]:!font-main-ui-mono [&_input]:![appearance:textfield] [&_input::-webkit-outer-spin-button]:!appearance-none [&_input::-webkit-inner-spin-button]:!appearance-none w-full">
                <InputTypeIn
                  type="number"
                  value={form.timeout_seconds}
                  onChange={(e) => {
                    set("timeout_seconds", e.target.value);
                    if (timeoutServerError) setTimeoutServerError(false);
                  }}
                  onBlur={handleTimeoutBlur}
                  placeholder={
                    spec ? String(spec.default_timeout_seconds) : undefined
                  }
                  variant={
                    isSubmitting
                      ? "disabled"
                      : timeoutServerError
                      ? "error"
                      : undefined
                  }
                  showClearButton={false}
                  rightSection={
                    spec?.default_timeout_seconds !== undefined &&
                    form.timeout_seconds !==
                      String(spec.default_timeout_seconds) ? (
                      <Button
                        prominence="tertiary"
                        size="xs"
                        icon={SvgRevert}
                        tooltip="Revert to Default"
                        onClick={() =>
                          set(
                            "timeout_seconds",
                            String(spec.default_timeout_seconds)
                          )
                        }
                        disabled={isSubmitting}
                      />
                    ) : undefined
                  }
                />
              </div>
            </FormField.Control>
            {!timeoutServerError && (
              <FormField.Description>
                Maximum time Onyx will wait for the endpoint to respond before
                applying the fail strategy. Must be greater than 0 and at most{" "}
                {MAX_TIMEOUT_SECONDS} seconds.
              </FormField.Description>
            )}
            <FormField.Message
              messages={{
                error: "Connection timed out. Try increasing the timeout.",
              }}
            />
          </FormField>
          {/* Endpoint URL */}
          <FormField
            className="w-full"
            state={endpointFieldError ? "error" : "idle"}
          >
            <FormField.Label>External API Endpoint URL</FormField.Label>
            <FormField.Control>
              <div className="[&_input::placeholder]:!font-main-ui-muted w-full">
                <InputTypeIn
                  value={form.endpoint_url}
                  onChange={(e) => {
                    set("endpoint_url", e.target.value);
                    if (endpointServerError) setEndpointServerError(null);
                  }}
                  onBlur={() => touch("endpoint_url")}
                  placeholder="https://your-api-endpoint.com"
                  variant={
                    isSubmitting
                      ? "disabled"
                      : endpointFieldError
                      ? "error"
                      : undefined
                  }
                />
              </div>
            </FormField.Control>
            {!endpointFieldError && (
              <FormField.Description>
                Only connect to servers you trust. You are responsible for
                actions taken and data shared with this connection.
              </FormField.Description>
            )}
            <FormField.Message messages={{ error: endpointFieldError }} />
          </FormField>
          {/* API key (optional in edit mode; placeholder shows the masked key) */}
          <FormField
            className="w-full"
            state={apiKeyFieldError ? "error" : "idle"}
          >
            <FormField.Label>API Key</FormField.Label>
            <FormField.Control>
              <PasswordInputTypeIn
                value={form.api_key}
                onChange={(e) => {
                  set("api_key", e.target.value);
                  if (apiKeyServerError) setApiKeyServerError(false);
                  if (isEdit) {
                    setApiKeyCleared(
                      e.target.value === "" && !!hook?.api_key_masked
                    );
                  }
                }}
                onBlur={() => touch("api_key")}
                placeholder={
                  isEdit
                    ? hook?.api_key_masked ?? "Leave blank to keep current key"
                    : undefined
                }
                disabled={isSubmitting}
                error={!!apiKeyFieldError}
              />
            </FormField.Control>
            {!apiKeyFieldError && (
              <FormField.Description>
                Onyx will use this key to authenticate with your API endpoint.
              </FormField.Description>
            )}
            <FormField.Message messages={{ error: apiKeyFieldError }} />
          </FormField>
          {/* Create mode: live connection status row during/after submit */}
          {!isEdit && (isSubmitting || isConnected) && (
            <Section
              flexDirection="row"
              alignItems="center"
              justifyContent="start"
              height="fit"
              gap={1}
              className="px-0.5"
            >
              <div className="p-0.5 shrink-0">
                {isConnected ? (
                  <SvgCheckCircle
                    size={16}
                    className="text-status-success-05"
                  />
                ) : (
                  <SvgLoader size={16} className="animate-spin text-text-03" />
                )}
              </div>
              <Text font="secondary-body" color="text-03">
                {isConnected ? "Connection valid." : "Verifying connection…"}
              </Text>
            </Section>
          )}
        </Modal.Body>
        <Modal.Footer>
          <BasicModalFooter
            cancel={
              <Disabled disabled={isSubmitting}>
                <Button
                  prominence="secondary"
                  onClick={() => handleOpenChange(false)}
                >
                  Cancel
                </Button>
              </Disabled>
            }
            submit={
              <Disabled disabled={isSubmitting || !isValid || !hasChanges}>
                <Button
                  onClick={handleSubmit}
                  icon={
                    isSubmitting && !isEdit
                      ? () => <SvgLoader size={16} className="animate-spin" />
                      : undefined
                  }
                >
                  {isEdit ? "Save Changes" : "Connect"}
                </Button>
              </Disabled>
            }
          />
        </Modal.Footer>
      </Modal.Content>
    </Modal>
  );
}

View File

@@ -29,6 +29,14 @@ export interface HookResponse {
updated_at: string;
}
// String-typed state backing the HookFormModal inputs. timeout_seconds is
// kept as a raw string so the text input can hold intermediate values; it is
// parsed with parseFloat before submission.
export interface HookFormState {
  name: string;
  endpoint_url: string;
  // Always starts empty — the stored secret is never echoed back.
  api_key: string;
  fail_strategy: HookFailStrategy;
  timeout_seconds: string;
}
export interface HookCreateRequest {
name: string;
hook_point: HookPoint;

View File

@@ -5,15 +5,27 @@ import {
HookValidateResponse,
} from "@/refresh-pages/admin/HooksPage/interfaces";
async function parseErrorDetail(
res: Response,
fallback: string
): Promise<string> {
export class HookAuthError extends Error {}
export class HookTimeoutError extends Error {}
export class HookConnectError extends Error {}
async function parseError(res: Response, fallback: string): Promise<Error> {
try {
const body = await res.json();
return body?.detail ?? fallback;
if (body?.error_code === "CREDENTIAL_INVALID") {
return new HookAuthError(body?.detail ?? "Invalid API key.");
}
if (body?.error_code === "GATEWAY_TIMEOUT") {
return new HookTimeoutError(body?.detail ?? "Connection timed out.");
}
if (body?.error_code === "BAD_GATEWAY") {
return new HookConnectError(
body?.detail ?? "Could not connect to endpoint."
);
}
return new Error(body?.detail ?? fallback);
} catch {
return fallback;
return new Error(fallback);
}
}
@@ -26,7 +38,7 @@ export async function createHook(
body: JSON.stringify(req),
});
if (!res.ok) {
throw new Error(await parseErrorDetail(res, "Failed to create hook"));
throw await parseError(res, "Failed to create hook");
}
return res.json();
}
@@ -41,7 +53,7 @@ export async function updateHook(
body: JSON.stringify(req),
});
if (!res.ok) {
throw new Error(await parseErrorDetail(res, "Failed to update hook"));
throw await parseError(res, "Failed to update hook");
}
return res.json();
}
@@ -49,7 +61,7 @@ export async function updateHook(
export async function deleteHook(id: number): Promise<void> {
const res = await fetch(`/api/admin/hooks/${id}`, { method: "DELETE" });
if (!res.ok) {
throw new Error(await parseErrorDetail(res, "Failed to delete hook"));
throw await parseError(res, "Failed to delete hook");
}
}
@@ -58,7 +70,7 @@ export async function activateHook(id: number): Promise<HookResponse> {
method: "POST",
});
if (!res.ok) {
throw new Error(await parseErrorDetail(res, "Failed to activate hook"));
throw await parseError(res, "Failed to activate hook");
}
return res.json();
}
@@ -68,7 +80,7 @@ export async function deactivateHook(id: number): Promise<HookResponse> {
method: "POST",
});
if (!res.ok) {
throw new Error(await parseErrorDetail(res, "Failed to deactivate hook"));
throw await parseError(res, "Failed to deactivate hook");
}
return res.json();
}
@@ -78,7 +90,7 @@ export async function validateHook(id: number): Promise<HookValidateResponse> {
method: "POST",
});
if (!res.ok) {
throw new Error(await parseErrorDetail(res, "Failed to validate hook"));
throw await parseError(res, "Failed to validate hook");
}
return res.json();
}

View File

@@ -45,6 +45,7 @@ import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
import { Section } from "@/layouts/general-layouts";
const route = ADMIN_ROUTES.LLM_MODELS;
@@ -65,6 +66,7 @@ const PROVIDER_DISPLAY_ORDER: string[] = [
"ollama_chat",
"openrouter",
"lm_studio",
"bifrost",
];
const PROVIDER_MODAL_MAP: Record<
@@ -138,6 +140,13 @@ const PROVIDER_MODAL_MAP: Record<
onOpenChange={onOpenChange}
/>
),
bifrost: (d, open, onOpenChange) => (
<BifrostModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
};
// ============================================================================

View File

@@ -1,6 +1,6 @@
"use client";
import { useMemo, useState } from "react";
import { useEffect, useMemo, useState } from "react";
import {
AzureIcon,
ElevenLabsIcon,
@@ -19,12 +19,18 @@ import {
import {
activateVoiceProvider,
deactivateVoiceProvider,
deleteVoiceProvider,
} from "@/lib/admin/voice/svc";
import { ThreeDotsLoader } from "@/components/Loading";
import { toast } from "@/hooks/useToast";
import { Callout } from "@/components/ui/callout";
import { Content } from "@opal/layouts";
import { SvgMicrophone } from "@opal/icons";
import { SvgMicrophone, SvgSlash, SvgUnplug } from "@opal/icons";
import { Button as OpalButton } from "@opal/components";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import { Section } from "@/layouts/general-layouts";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import VoiceProviderSetupModal from "@/app/admin/configuration/voice/VoiceProviderSetupModal";
interface ModelDetails {
@@ -129,10 +135,152 @@ function getProviderIcon(
type ProviderMode = "stt" | "tts";
// Maps a voice provider type id to its display label; unknown types are
// passed through unchanged.
function getProviderLabel(providerType: string): string {
  const labels: Record<string, string> = {
    openai: "OpenAI",
    azure: "Azure",
    elevenlabs: "ElevenLabs",
  };
  return labels[providerType] ?? providerType;
}
const NO_DEFAULT_VALUE = "__none__";
const route = ADMIN_ROUTES.VOICE;
const pageDescription =
"Configure speech-to-text and text-to-speech providers for voice input and spoken responses.";
interface VoiceDisconnectModalProps {
disconnectTarget: {
providerId: number;
providerLabel: string;
providerType: string;
};
providers: VoiceProviderView[];
replacementProviderId: string | null;
onReplacementChange: (id: string | null) => void;
onClose: () => void;
onDisconnect: () => void;
}
/**
 * Confirmation modal for disconnecting a voice provider.
 *
 * If the target provider is the current STT or TTS default ("active"), the
 * user must pick a replacement default (or explicitly choose "No Default")
 * before the Disconnect button enables. Replacement selection state lives in
 * the parent via `replacementProviderId` / `onReplacementChange`.
 */
function VoiceDisconnectModal({
  disconnectTarget,
  providers,
  replacementProviderId,
  onReplacementChange,
  onClose,
  onDisconnect,
}: VoiceDisconnectModalProps) {
  const targetProvider = providers.find(
    (p) => p.id === disconnectTarget.providerId
  );
  // "Active" means the provider is the current default for STT and/or TTS.
  const isActive =
    (targetProvider?.is_default_stt ?? false) ||
    (targetProvider?.is_default_tts ?? false);
  // Find other configured providers that could serve as replacements
  const replacementOptions = providers.filter(
    (p) => p.id !== disconnectTarget.providerId && p.has_api_key
  );
  const needsReplacement = isActive;
  const hasReplacements = replacementOptions.length > 0;
  // Auto-select first replacement when modal opens
  // (runs once on mount by design — hence the empty dependency array).
  useEffect(() => {
    if (needsReplacement && hasReplacements && !replacementProviderId) {
      const first = replacementOptions[0];
      if (first) onReplacementChange(String(first.id));
    }
  }, []); // eslint-disable-line react-hooks/exhaustive-deps
  return (
    <ConfirmationModalLayout
      icon={SvgUnplug}
      title={`Disconnect ${disconnectTarget.providerLabel}`}
      description="Voice models"
      onClose={onClose}
      submit={
        <OpalButton
          variant="danger"
          onClick={onDisconnect}
          disabled={
            needsReplacement && hasReplacements && !replacementProviderId
          }
        >
          Disconnect
        </OpalButton>
      }
    >
      {needsReplacement ? (
        hasReplacements ? (
          /* Active provider with candidates: force choosing a new default */
          <Section alignItems="start">
            <Text as="p" text03>
              <b>{disconnectTarget.providerLabel}</b> models will no longer be
              used for speech-to-text or text-to-speech, and it will no longer
              be your default. Session history will be preserved.
            </Text>
            <Section alignItems="start" gap={0.25}>
              <Text as="p" text04>
                Set New Default
              </Text>
              <InputSelect
                value={replacementProviderId ?? undefined}
                onValueChange={(v) => onReplacementChange(v)}
              >
                <InputSelect.Trigger placeholder="Select a replacement provider" />
                <InputSelect.Content>
                  {replacementOptions.map((p) => (
                    <InputSelect.Item
                      key={p.id}
                      value={String(p.id)}
                      icon={getProviderIcon(p.provider_type)}
                    >
                      {getProviderLabel(p.provider_type)}
                    </InputSelect.Item>
                  ))}
                  <InputSelect.Separator />
                  <InputSelect.Item value={NO_DEFAULT_VALUE} icon={SvgSlash}>
                    <span>
                      <b>No Default</b>
                      <span className="text-text-03"> (Disable Voice)</span>
                    </span>
                  </InputSelect.Item>
                </InputSelect.Content>
              </InputSelect>
            </Section>
          </Section>
        ) : (
          /* Active provider but nothing to replace it with */
          <>
            <Text as="p" text03>
              <b>{disconnectTarget.providerLabel}</b> models will no longer be
              used for speech-to-text or text-to-speech, and it will no longer
              be your default.
            </Text>
            <Text as="p" text03>
              Connect another provider to continue using voice.
            </Text>
          </>
        )
      ) : (
        /* Non-default provider: simple confirmation */
        <>
          <Text as="p" text03>
            <b>{disconnectTarget.providerLabel}</b> models will no longer be
            available for voice.
          </Text>
          <Text as="p" text03>
            Session history will be preserved.
          </Text>
        </>
      )}
    </ConfirmationModalLayout>
  );
}
export default function VoiceConfigurationPage() {
const [modalOpen, setModalOpen] = useState(false);
const [selectedProvider, setSelectedProvider] = useState<string | null>(null);
@@ -146,6 +294,14 @@ export default function VoiceConfigurationPage() {
const [ttsActivationError, setTTSActivationError] = useState<string | null>(
null
);
const [disconnectTarget, setDisconnectTarget] = useState<{
providerId: number;
providerLabel: string;
providerType: string;
} | null>(null);
const [replacementProviderId, setReplacementProviderId] = useState<
string | null
>(null);
const { providers, error, isLoading, refresh: mutate } = useVoiceProviders();
@@ -237,6 +393,65 @@ export default function VoiceConfigurationPage() {
handleModalClose();
};
const handleDisconnect = async () => {
if (!disconnectTarget) return;
try {
const targetProvider = providers.find(
(p) => p.id === disconnectTarget.providerId
);
// If a replacement was selected (not "No Default"), activate it for each
// mode the disconnected provider was default for
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
const repId = Number(replacementProviderId);
if (targetProvider?.is_default_stt) {
const resp = await activateVoiceProvider(repId, "stt");
if (!resp.ok) {
const errorBody = await resp.json().catch(() => ({}));
throw new Error(
typeof errorBody?.detail === "string"
? errorBody.detail
: "Failed to activate replacement STT provider."
);
}
}
if (targetProvider?.is_default_tts) {
const resp = await activateVoiceProvider(repId, "tts");
if (!resp.ok) {
const errorBody = await resp.json().catch(() => ({}));
throw new Error(
typeof errorBody?.detail === "string"
? errorBody.detail
: "Failed to activate replacement TTS provider."
);
}
}
}
const response = await deleteVoiceProvider(disconnectTarget.providerId);
if (!response.ok) {
const errorBody = await response.json().catch(() => ({}));
throw new Error(
typeof errorBody?.detail === "string"
? errorBody.detail
: "Failed to disconnect provider."
);
}
await mutate();
toast.success(`${disconnectTarget.providerLabel} disconnected`);
} catch (err) {
console.error("Failed to disconnect voice provider:", err);
toast.error(
err instanceof Error ? err.message : "Unexpected error occurred."
);
} finally {
setDisconnectTarget(null);
setReplacementProviderId(null);
}
};
const isProviderConfigured = (provider?: VoiceProviderView): boolean => {
return !!provider?.has_api_key;
};
@@ -289,6 +504,16 @@ export default function VoiceConfigurationPage() {
onEdit={() => {
if (provider) handleEdit(provider, mode, model.id);
}}
onDisconnect={
status !== "disconnected" && provider
? () =>
setDisconnectTarget({
providerId: provider.id,
providerLabel: getProviderLabel(model.providerType),
providerType: model.providerType,
})
: undefined
}
/>
);
};
@@ -412,6 +637,20 @@ export default function VoiceConfigurationPage() {
</div>
</SettingsLayouts.Body>
{disconnectTarget && (
<VoiceDisconnectModal
disconnectTarget={disconnectTarget}
providers={providers}
replacementProviderId={replacementProviderId}
onReplacementChange={setReplacementProviderId}
onClose={() => {
setDisconnectTarget(null);
setReplacementProviderId(null);
}}
onDisconnect={() => void handleDisconnect()}
/>
)}
{modalOpen && selectedProvider && (
<VoiceProviderSetupModal
providerType={selectedProvider}

View File

@@ -0,0 +1,278 @@
"use client";
import { useState, useEffect } from "react";
import { markdown } from "@opal/utils";
import { useSWRConfig } from "swr";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import { fetchBifrostModels } from "@/app/admin/configuration/llm/utils";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
ModelsField,
DisplayNameField,
ModelsAccessField,
FieldSeparator,
FieldWrapper,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { toast } from "@/hooks/useToast";
const BIFROST_PROVIDER_NAME = LLMProviderName.BIFROST;
const DEFAULT_API_BASE = "";
interface BifrostModalValues extends BaseLLMFormValues {
api_key: string;
api_base: string;
}
interface BifrostModalInternalsProps {
formikProps: FormikProps<BifrostModalValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
modelConfigurations: ModelConfiguration[];
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
}
/**
 * Inner form body for the Bifrost provider modal, rendered inside Formik.
 *
 * Shows API base / API key inputs and a models section; model choices come
 * from (in order of preference) freshly fetched models, the existing
 * provider's configurations, or the well-known provider's defaults.
 * Onboarding mode swaps the full models list for a single default-model field
 * and hides display-name / model-access controls.
 */
function BifrostModalInternals({
  formikProps,
  existingLlmProvider,
  fetchedModels,
  setFetchedModels,
  modelConfigurations,
  isTesting,
  onClose,
  isOnboarding,
}: BifrostModalInternalsProps) {
  // Prefer freshly fetched models; otherwise fall back to the stored or
  // well-known configurations.
  const currentModels =
    fetchedModels.length > 0
      ? fetchedModels
      : existingLlmProvider?.model_configurations || modelConfigurations;
  // Model fetching requires at least an API base URL.
  const isFetchDisabled = !formikProps.values.api_base;
  // Queries the gateway for its model list; throws so callers can surface
  // the error (the refetch button and the mount-time effect both use it).
  const handleFetchModels = async () => {
    const { models, error } = await fetchBifrostModels({
      api_base: formikProps.values.api_base,
      api_key: formikProps.values.api_key || undefined,
      provider_name: existingLlmProvider?.name,
    });
    if (error) {
      throw new Error(error);
    }
    setFetchedModels(models);
  };
  // Auto-fetch models on initial load when editing an existing provider
  useEffect(() => {
    if (existingLlmProvider && !isFetchDisabled) {
      handleFetchModels().catch((err) => {
        console.error("Failed to fetch Bifrost models:", err);
        toast.error(
          err instanceof Error ? err.message : "Failed to fetch models"
        );
      });
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);
  return (
    <LLMConfigurationModalWrapper
      providerEndpoint={LLMProviderName.BIFROST}
      existingProviderName={existingLlmProvider?.name}
      onClose={onClose}
      isFormValid={formikProps.isValid}
      isDirty={formikProps.dirty}
      isTesting={isTesting}
      isSubmitting={formikProps.isSubmitting}
    >
      <FieldWrapper>
        <InputLayouts.Vertical
          name="api_base"
          title="API Base URL"
          subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
        >
          <InputTypeInField
            name="api_base"
            placeholder="https://your-bifrost-gateway.com/v1"
          />
        </InputLayouts.Vertical>
      </FieldWrapper>
      <FieldWrapper>
        <InputLayouts.Vertical
          name="api_key"
          title="API Key"
          optional={true}
          subDescription={markdown(
            "Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
          )}
        >
          <PasswordInputTypeInField name="api_key" placeholder="API Key" />
        </InputLayouts.Vertical>
      </FieldWrapper>
      {!isOnboarding && (
        <>
          <FieldSeparator />
          <DisplayNameField disabled={!!existingLlmProvider} />
        </>
      )}
      <FieldSeparator />
      {isOnboarding ? (
        <SingleDefaultModelField placeholder="E.g. anthropic/claude-sonnet-4-6" />
      ) : (
        <ModelsField
          modelConfigurations={currentModels}
          formikProps={formikProps}
          recommendedDefaultModel={null}
          shouldShowAutoUpdateToggle={false}
          onRefetch={isFetchDisabled ? undefined : handleFetchModels}
        />
      )}
      {!isOnboarding && (
        <>
          <FieldSeparator />
          <ModelsAccessField formikProps={formikProps} />
        </>
      )}
    </LLMConfigurationModalWrapper>
  );
}
/**
 * Configuration modal for the Bifrost LLM gateway provider.
 *
 * Supports two variants: the standard "llm-configuration" flow (full Formik
 * validation + submitLLMProvider) and the "onboarding" flow (reduced schema +
 * submitOnboardingProvider). Model configurations fetched from the gateway
 * are kept in local state and preferred over stored/well-known defaults at
 * submit time.
 */
export default function BifrostModal({
  variant = "llm-configuration",
  existingLlmProvider,
  shouldMarkAsDefault,
  open,
  onOpenChange,
  defaultModelName,
  onboardingState,
  onboardingActions,
  llmDescriptor,
}: LLMProviderFormProps) {
  const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
  const [isTesting, setIsTesting] = useState(false);
  const isOnboarding = variant === "onboarding";
  const { mutate } = useSWRConfig();
  const { wellKnownLLMProvider } = useWellKnownLLMProvider(
    BIFROST_PROVIDER_NAME
  );
  // All hooks above run unconditionally, so this early return keeps React's
  // hook order stable across open/closed renders.
  if (open === false) return null;
  const onClose = () => onOpenChange?.(false);
  const modelConfigurations = buildAvailableModelConfigurations(
    existingLlmProvider,
    wellKnownLLMProvider ?? llmDescriptor
  );
  // Onboarding starts from a blank slate; otherwise seed from the existing
  // provider (or defaults when creating via the standard flow).
  const initialValues: BifrostModalValues = isOnboarding
    ? ({
        ...buildOnboardingInitialValues(),
        name: BIFROST_PROVIDER_NAME,
        provider: BIFROST_PROVIDER_NAME,
        api_key: "",
        api_base: DEFAULT_API_BASE,
        default_model_name: "",
      } as BifrostModalValues)
    : {
        ...buildDefaultInitialValues(
          existingLlmProvider,
          modelConfigurations,
          defaultModelName
        ),
        api_key: existingLlmProvider?.api_key ?? "",
        api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
      };
  // api_base is required in both flows; onboarding additionally requires a
  // default model name and skips the default schema.
  const validationSchema = isOnboarding
    ? Yup.object().shape({
        api_base: Yup.string().required("API Base URL is required"),
        default_model_name: Yup.string().required("Model name is required"),
      })
    : buildDefaultValidationSchema().shape({
        api_base: Yup.string().required("API Base URL is required"),
      });
  return (
    <Formik
      initialValues={initialValues}
      validationSchema={validationSchema}
      validateOnMount={true}
      onSubmit={async (values, { setSubmitting }) => {
        if (isOnboarding && onboardingState && onboardingActions) {
          // Onboarding: only freshly fetched models are submitted.
          const modelConfigsToUse =
            fetchedModels.length > 0 ? fetchedModels : [];
          await submitOnboardingProvider({
            providerName: BIFROST_PROVIDER_NAME,
            payload: {
              ...values,
              model_configurations: modelConfigsToUse,
            },
            onboardingState,
            onboardingActions,
            isCustomProvider: false,
            onClose,
            setIsSubmitting: setSubmitting,
          });
        } else {
          await submitLLMProvider({
            providerName: BIFROST_PROVIDER_NAME,
            values,
            initialValues,
            modelConfigurations:
              fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
            existingLlmProvider,
            shouldMarkAsDefault,
            setIsTesting,
            mutate,
            onClose,
            setSubmitting,
          });
        }
      }}
    >
      {(formikProps) => (
        <BifrostModalInternals
          formikProps={formikProps}
          existingLlmProvider={existingLlmProvider}
          fetchedModels={fetchedModels}
          setFetchedModels={setFetchedModels}
          modelConfigurations={modelConfigurations}
          isTesting={isTesting}
          onClose={onClose}
          isOnboarding={isOnboarding}
        />
      )}
    </Formik>
  );
}

View File

@@ -9,6 +9,7 @@ import CustomModal from "@/sections/modals/llmConfig/CustomModal";
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
function detectIfRealOpenAIProvider(provider: LLMProviderView) {
return (
@@ -56,6 +57,8 @@ export function getModalForExistingProvider(
return <LMStudioForm {...props} />;
case LLMProviderName.LITELLM_PROXY:
return <LiteLLMProxyModal {...props} />;
case LLMProviderName.BIFROST:
return <BifrostModal {...props} />;
default:
return <CustomModal {...props} />;
}

View File

@@ -0,0 +1,246 @@
import { test, expect, Page, Locator } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
const IMAGE_GENERATION_URL = "/admin/configuration/image-generation";
// Fixture: a configured provider that is NOT the current default.
const FAKE_CONNECTED_CONFIG = {
  image_provider_id: "openai_dalle_3",
  model_configuration_id: 100,
  model_name: "dall-e-3",
  llm_provider_id: 100,
  llm_provider_name: "openai-dalle3",
  is_default: false,
};
// Fixture: the provider currently set as the default.
const FAKE_DEFAULT_CONFIG = {
  image_provider_id: "openai_gpt_image_1",
  model_configuration_id: 101,
  model_name: "gpt-image-1",
  llm_provider_id: 101,
  llm_provider_name: "openai-gpt-image-1",
  is_default: true,
};
// Locates a provider card by its exact `image-gen-provider-<id>` label.
function getProviderCard(page: Page, providerId: string): Locator {
  return page.getByLabel(`image-gen-provider-${providerId}`, { exact: true });
}
// The page's main content region, used for full-area screenshots.
function mainContainer(page: Page): Locator {
  return page.locator("[data-main-container]");
}
/**
* Sets up route mocks so the page sees configured providers
* without needing real API keys.
*/
async function mockImageGenApis(
  page: Page,
  configs: (typeof FAKE_CONNECTED_CONFIG)[]
) {
  // GET returns the fixture list; all other verbs pass through untouched so
  // individual tests can mock DELETE/POST themselves.
  await page.route("**/api/admin/image-generation/config", async (route) => {
    if (route.request().method() === "GET") {
      await route.fulfill({ status: 200, json: configs });
    } else {
      await route.continue();
    }
  });
  // The page also queries LLM providers; an empty list suffices here.
  await page.route(
    "**/api/admin/llm/provider?include_image_gen=true",
    async (route) => {
      await route.fulfill({ status: 200, json: { providers: [] } });
    }
  );
}
// Covers the disconnect flow for image-generation providers:
// non-default disconnect, default-with-alternatives, default-without-alternatives,
// and absence of the disconnect button for unconfigured providers.
// All backend calls are mocked via page.route, so no real API keys are needed.
test.describe("Image Generation Provider Disconnect", () => {
  test.beforeEach(async ({ page }) => {
    // Fresh session per test: drop cookies, then authenticate as admin.
    await page.context().clearCookies();
    await loginAs(page, "admin");
  });
  test("should disconnect a connected (non-default) provider", async ({
    page,
  }) => {
    // One non-default provider (DALL-E 3) plus the default (GPT Image 1).
    const configs = [{ ...FAKE_CONNECTED_CONFIG }, { ...FAKE_DEFAULT_CONFIG }];
    await mockImageGenApis(page, configs);
    await page.goto(IMAGE_GENERATION_URL);
    await page.waitForSelector("text=Image Generation Model", {
      timeout: 20000,
    });
    const card = getProviderCard(page, "openai_dalle_3");
    await card.waitFor({ state: "visible", timeout: 10000 });
    await expectElementScreenshot(mainContainer(page), {
      name: "image-gen-disconnect-non-default-before",
    });
    // Verify disconnect button exists and is enabled
    const disconnectButton = card.getByRole("button", {
      name: "Disconnect DALL-E 3",
    });
    await expect(disconnectButton).toBeVisible();
    await expect(disconnectButton).toBeEnabled();
    // Mock the DELETE to succeed and update the config list
    await page.route(
      "**/api/admin/image-generation/config/openai_dalle_3",
      async (route) => {
        if (route.request().method() === "DELETE") {
          // Update the GET mock to return only the default config
          await page.unroute("**/api/admin/image-generation/config");
          await page.route(
            "**/api/admin/image-generation/config",
            async (route) => {
              if (route.request().method() === "GET") {
                await route.fulfill({
                  status: 200,
                  json: [{ ...FAKE_DEFAULT_CONFIG }],
                });
              } else {
                await route.continue();
              }
            }
          );
          await route.fulfill({ status: 200, json: {} });
        } else {
          await route.continue();
        }
      }
    );
    // Click disconnect
    await disconnectButton.click();
    // Verify confirmation modal appears
    const confirmDialog = page.getByRole("dialog");
    await expect(confirmDialog).toBeVisible({ timeout: 5000 });
    await expect(confirmDialog).toContainText("Disconnect DALL-E 3");
    await expectElementScreenshot(confirmDialog, {
      name: "image-gen-disconnect-non-default-modal",
    });
    // Click Disconnect in the confirmation modal
    const confirmButton = confirmDialog.getByRole("button", {
      name: "Disconnect",
    });
    await confirmButton.click();
    // Verify the card reverts to disconnected state (shows "Connect" button)
    await expect(card.getByRole("button", { name: "Connect" })).toBeVisible({
      timeout: 10000,
    });
    await expectElementScreenshot(mainContainer(page), {
      name: "image-gen-disconnect-non-default-after",
    });
  });
  test("should show replacement dropdown when disconnecting default provider with alternatives", async ({
    page,
  }) => {
    // Default provider plus one alternative, so a replacement can be offered.
    const configs = [{ ...FAKE_CONNECTED_CONFIG }, { ...FAKE_DEFAULT_CONFIG }];
    await mockImageGenApis(page, configs);
    await page.goto(IMAGE_GENERATION_URL);
    await page.waitForSelector("text=Image Generation Model", {
      timeout: 20000,
    });
    const defaultCard = getProviderCard(page, "openai_gpt_image_1");
    await defaultCard.waitFor({ state: "visible", timeout: 10000 });
    // The disconnect button should be visible and enabled
    const disconnectButton = defaultCard.getByRole("button", {
      name: "Disconnect GPT Image 1",
    });
    await expect(disconnectButton).toBeVisible();
    await expect(disconnectButton).toBeEnabled();
    await disconnectButton.click();
    const confirmDialog = page.getByRole("dialog");
    await expect(confirmDialog).toBeVisible({ timeout: 5000 });
    // Should show replacement dropdown since there's an alternative
    await expect(
      confirmDialog.getByText("Session history will be preserved")
    ).toBeVisible();
    // Disconnect button should be enabled because first replacement is auto-selected
    const confirmButton = confirmDialog.getByRole("button", {
      name: "Disconnect",
    });
    await expect(confirmButton).toBeEnabled();
    await expectElementScreenshot(confirmDialog, {
      name: "image-gen-disconnect-default-with-alt-modal",
    });
  });
  test("should show connect message when disconnecting default provider with no alternatives", async ({
    page,
  }) => {
    // Only the default config — no other providers configured
    await mockImageGenApis(page, [{ ...FAKE_DEFAULT_CONFIG }]);
    await page.goto(IMAGE_GENERATION_URL);
    await page.waitForSelector("text=Image Generation Model", {
      timeout: 20000,
    });
    const defaultCard = getProviderCard(page, "openai_gpt_image_1");
    await defaultCard.waitFor({ state: "visible", timeout: 10000 });
    const disconnectButton = defaultCard.getByRole("button", {
      name: "Disconnect GPT Image 1",
    });
    await disconnectButton.click();
    const confirmDialog = page.getByRole("dialog");
    await expect(confirmDialog).toBeVisible({ timeout: 5000 });
    // Should show message about connecting another provider
    await expect(
      confirmDialog.getByText("Connect another provider")
    ).toBeVisible();
    // Disconnect button should be enabled
    const confirmButton = confirmDialog.getByRole("button", {
      name: "Disconnect",
    });
    await expect(confirmButton).toBeEnabled();
    await expectElementScreenshot(confirmDialog, {
      name: "image-gen-disconnect-no-alt-modal",
    });
  });
  test("should not show disconnect button for unconfigured providers", async ({
    page,
  }) => {
    // Only GPT Image 1 is configured; DALL-E 3 renders as unconfigured.
    await mockImageGenApis(page, [{ ...FAKE_DEFAULT_CONFIG }]);
    await page.goto(IMAGE_GENERATION_URL);
    await page.waitForSelector("text=Image Generation Model", {
      timeout: 20000,
    });
    // DALL-E 3 is not configured — should not have a disconnect button
    const card = getProviderCard(page, "openai_dalle_3");
    await card.waitFor({ state: "visible", timeout: 10000 });
    const disconnectButton = card.getByRole("button", {
      name: "Disconnect DALL-E 3",
    });
    await expect(disconnectButton).not.toBeVisible();
    await expectElementScreenshot(mainContainer(page), {
      name: "image-gen-disconnect-unconfigured",
    });
  });
});

View File

@@ -0,0 +1,317 @@
import { test, expect, Page, Locator } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
// Admin page under test.
const VOICE_URL = "/admin/configuration/voice";
// Mock voice-provider payloads. All variants share id=1 for OpenAI so the
// DELETE route (**/providers/1) matches regardless of which fixture is used.
const FAKE_PROVIDERS = {
  // OpenAI configured and default for speech-to-text only.
  openai_active_stt: {
    id: 1,
    name: "openai",
    provider_type: "openai",
    is_default_stt: true,
    is_default_tts: false,
    stt_model: "whisper",
    tts_model: null,
    default_voice: null,
    has_api_key: true,
    target_uri: null,
  },
  // OpenAI default for both STT and TTS.
  openai_active_both: {
    id: 1,
    name: "openai",
    provider_type: "openai",
    is_default_stt: true,
    is_default_tts: true,
    stt_model: "whisper",
    tts_model: "tts-1",
    default_voice: "alloy",
    has_api_key: true,
    target_uri: null,
  },
  // OpenAI connected (has API key) but not default for anything.
  openai_connected: {
    id: 1,
    name: "openai",
    provider_type: "openai",
    is_default_stt: false,
    is_default_tts: false,
    stt_model: null,
    tts_model: null,
    default_voice: null,
    has_api_key: true,
    target_uri: null,
  },
  // ElevenLabs connected but not default — serves as the "alternative".
  elevenlabs_connected: {
    id: 2,
    name: "elevenlabs",
    provider_type: "elevenlabs",
    is_default_stt: false,
    is_default_tts: false,
    stt_model: null,
    tts_model: null,
    default_voice: null,
    has_api_key: true,
    target_uri: null,
  },
};
/** Locate a voice model card by its exact aria-label. */
function findModelCard(page: Page, ariaLabel: string): Locator {
  const options = { exact: true };
  return page.getByLabel(ariaLabel, options);
}
/** The page's main content container, used for full-page screenshots. */
function mainContainer(page: Page): Locator {
  const selector = "[data-main-container]";
  return page.locator(selector);
}
/**
 * Mocks the voice-provider listing so the page renders the supplied
 * providers without any real backend configuration.
 */
async function mockVoiceApis(
  page: Page,
  providers: (typeof FAKE_PROVIDERS)[keyof typeof FAKE_PROVIDERS][]
) {
  await page.route("**/api/admin/voice/providers", async (route) => {
    // Only GETs are stubbed; all other verbs fall through to the app.
    if (route.request().method() !== "GET") {
      await route.continue();
      return;
    }
    await route.fulfill({ status: 200, json: providers });
  });
}
// Covers the disconnect flow for voice providers. A single provider backs
// both STT and TTS cards, so disconnecting affects both; the confirmation
// modal is keyed to the provider name (OpenAI), not the model name.
test.describe("Voice Provider Disconnect", () => {
  test.beforeEach(async ({ page }) => {
    // Fresh session per test: drop cookies, then authenticate as admin.
    await page.context().clearCookies();
    await loginAs(page, "admin");
  });
  test("should disconnect a non-active provider and affect both STT and TTS cards", async ({
    page,
  }) => {
    // Two connected providers, neither is default for any mode.
    const providers = [
      { ...FAKE_PROVIDERS.openai_connected },
      { ...FAKE_PROVIDERS.elevenlabs_connected },
    ];
    await mockVoiceApis(page, providers);
    await page.goto(VOICE_URL);
    await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
    const whisperCard = findModelCard(page, "voice-stt-whisper");
    await whisperCard.waitFor({ state: "visible", timeout: 10000 });
    await expectElementScreenshot(mainContainer(page), {
      name: "voice-disconnect-non-active-before",
    });
    const disconnectButton = whisperCard.getByRole("button", {
      name: "Disconnect Whisper",
    });
    await expect(disconnectButton).toBeVisible();
    await expect(disconnectButton).toBeEnabled();
    // Mock DELETE to succeed and remove OpenAI from provider list
    await page.route("**/api/admin/voice/providers/1", async (route) => {
      if (route.request().method() === "DELETE") {
        await page.unroute("**/api/admin/voice/providers");
        await page.route("**/api/admin/voice/providers", async (route) => {
          if (route.request().method() === "GET") {
            await route.fulfill({
              status: 200,
              json: [{ ...FAKE_PROVIDERS.elevenlabs_connected }],
            });
          } else {
            await route.continue();
          }
        });
        await route.fulfill({ status: 200, json: {} });
      } else {
        await route.continue();
      }
    });
    await disconnectButton.click();
    // Modal shows provider name, not model name
    const confirmDialog = page.getByRole("dialog");
    await expect(confirmDialog).toBeVisible({ timeout: 5000 });
    await expect(confirmDialog).toContainText("Disconnect OpenAI");
    await expectElementScreenshot(confirmDialog, {
      name: "voice-disconnect-non-active-modal",
    });
    const confirmButton = confirmDialog.getByRole("button", {
      name: "Disconnect",
    });
    await confirmButton.click();
    // Both STT and TTS cards for OpenAI revert to disconnected
    await expect(
      whisperCard.getByRole("button", { name: "Connect" })
    ).toBeVisible({ timeout: 10000 });
    const tts1Card = findModelCard(page, "voice-tts-tts-1");
    await expect(tts1Card.getByRole("button", { name: "Connect" })).toBeVisible(
      { timeout: 10000 }
    );
    await expectElementScreenshot(mainContainer(page), {
      name: "voice-disconnect-non-active-after",
    });
  });
  test("should show replacement dropdown when disconnecting active provider with alternatives", async ({
    page,
  }) => {
    // OpenAI is active for STT, ElevenLabs is also configured
    const providers = [
      { ...FAKE_PROVIDERS.openai_active_stt },
      { ...FAKE_PROVIDERS.elevenlabs_connected },
    ];
    await mockVoiceApis(page, providers);
    await page.goto(VOICE_URL);
    await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
    const whisperCard = findModelCard(page, "voice-stt-whisper");
    await whisperCard.waitFor({ state: "visible", timeout: 10000 });
    await expectElementScreenshot(mainContainer(page), {
      name: "voice-disconnect-active-with-alt-before",
    });
    const disconnectButton = whisperCard.getByRole("button", {
      name: "Disconnect Whisper",
    });
    await disconnectButton.click();
    const confirmDialog = page.getByRole("dialog");
    await expect(confirmDialog).toBeVisible({ timeout: 5000 });
    await expect(confirmDialog).toContainText("Disconnect OpenAI");
    // Should show replacement text and dropdown
    await expect(
      confirmDialog.getByText("Session history will be preserved")
    ).toBeVisible();
    // Disconnect button should be enabled because first replacement is auto-selected
    const confirmButton = confirmDialog.getByRole("button", {
      name: "Disconnect",
    });
    await expect(confirmButton).toBeEnabled();
    await expectElementScreenshot(confirmDialog, {
      name: "voice-disconnect-active-with-alt-modal",
    });
  });
  test("should show replacement when provider is default for both STT and TTS", async ({
    page,
  }) => {
    // OpenAI is default for both modes, ElevenLabs also configured
    const providers = [
      { ...FAKE_PROVIDERS.openai_active_both },
      { ...FAKE_PROVIDERS.elevenlabs_connected },
    ];
    await mockVoiceApis(page, providers);
    await page.goto(VOICE_URL);
    await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
    const whisperCard = findModelCard(page, "voice-stt-whisper");
    await whisperCard.waitFor({ state: "visible", timeout: 10000 });
    await expectElementScreenshot(mainContainer(page), {
      name: "voice-disconnect-both-modes-before",
    });
    const disconnectButton = whisperCard.getByRole("button", {
      name: "Disconnect Whisper",
    });
    await disconnectButton.click();
    const confirmDialog = page.getByRole("dialog");
    await expect(confirmDialog).toBeVisible({ timeout: 5000 });
    await expect(confirmDialog).toContainText("Disconnect OpenAI");
    // Should mention both modes
    await expect(
      confirmDialog.getByText("speech-to-text or text-to-speech")
    ).toBeVisible();
    // Should show replacement dropdown
    await expect(
      confirmDialog.getByText("Session history will be preserved")
    ).toBeVisible();
    const confirmButton = confirmDialog.getByRole("button", {
      name: "Disconnect",
    });
    await expect(confirmButton).toBeEnabled();
    await expectElementScreenshot(confirmDialog, {
      name: "voice-disconnect-both-modes-modal",
    });
  });
  test("should show connect message when disconnecting active provider with no alternatives", async ({
    page,
  }) => {
    // Only OpenAI configured, active for STT — no other providers
    const providers = [{ ...FAKE_PROVIDERS.openai_active_stt }];
    await mockVoiceApis(page, providers);
    await page.goto(VOICE_URL);
    await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
    const whisperCard = findModelCard(page, "voice-stt-whisper");
    await whisperCard.waitFor({ state: "visible", timeout: 10000 });
    await expectElementScreenshot(mainContainer(page), {
      name: "voice-disconnect-no-alt-before",
    });
    const disconnectButton = whisperCard.getByRole("button", {
      name: "Disconnect Whisper",
    });
    await disconnectButton.click();
    const confirmDialog = page.getByRole("dialog");
    await expect(confirmDialog).toBeVisible({ timeout: 5000 });
    await expect(confirmDialog).toContainText("Disconnect OpenAI");
    // Should show message about connecting another provider
    await expect(
      confirmDialog.getByText("Connect another provider")
    ).toBeVisible();
    // Disconnect button should be enabled
    const confirmButton = confirmDialog.getByRole("button", {
      name: "Disconnect",
    });
    await expect(confirmButton).toBeEnabled();
    await expectElementScreenshot(confirmDialog, {
      name: "voice-disconnect-no-alt-modal",
    });
  });
  test("should not show disconnect button for unconfigured provider", async ({
    page,
  }) => {
    // No providers configured at all — cards render in unconfigured state.
    await mockVoiceApis(page, []);
    await page.goto(VOICE_URL);
    await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
    const whisperCard = findModelCard(page, "voice-stt-whisper");
    await whisperCard.waitFor({ state: "visible", timeout: 10000 });
    const disconnectButton = whisperCard.getByRole("button", {
      name: "Disconnect Whisper",
    });
    await expect(disconnectButton).not.toBeVisible();
    await expectElementScreenshot(mainContainer(page), {
      name: "voice-disconnect-unconfigured",
    });
  });
});

View File

@@ -0,0 +1,394 @@
import { test, expect, Page, Locator } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
// Admin page under test.
const WEB_SEARCH_URL = "/admin/configuration/web-search";
// Mock search-engine providers: Exa is the active engine, Brave is
// configured but inactive (used as the disconnect target / alternative).
const FAKE_SEARCH_PROVIDERS = {
  exa: {
    id: 1,
    name: "Exa",
    provider_type: "exa",
    is_active: true,
    config: null,
    has_api_key: true,
  },
  brave: {
    id: 2,
    name: "Brave",
    provider_type: "brave",
    is_active: false,
    config: null,
    has_api_key: true,
  },
};
// Mock content-crawler providers; ids are disjoint from search providers
// so DELETE routes (**/content-providers/10) are unambiguous.
const FAKE_CONTENT_PROVIDERS = {
  firecrawl: {
    id: 10,
    name: "Firecrawl",
    provider_type: "firecrawl",
    is_active: true,
    config: { base_url: "https://api.firecrawl.dev/v2/scrape" },
    has_api_key: true,
  },
  exa: {
    id: 11,
    name: "Exa",
    provider_type: "exa",
    is_active: false,
    config: null,
    has_api_key: true,
  },
};
/**
 * Find a provider card by its visible label text.
 * NOTE(review): matches the FIRST rounded card containing the text, so it
 * relies on DOM section order when the same label (e.g. "Exa") appears in
 * both the search-engine and web-crawler sections — confirm if flaky.
 */
function findProviderCard(page: Page, providerLabel: string): Locator {
  const cards = page.locator("div.rounded-16");
  const matching = cards.filter({ hasText: providerLabel });
  return matching.first();
}
/** The page's main content container, used for full-page screenshots. */
function mainContainer(page: Page): Locator {
  const selector = "[data-main-container]";
  return page.locator(selector);
}
/**
 * Mocks both web-search listing endpoints (search engines and content
 * crawlers) so the page renders the supplied providers without real keys.
 */
async function mockWebSearchApis(
  page: Page,
  searchProviders: (typeof FAKE_SEARCH_PROVIDERS)[keyof typeof FAKE_SEARCH_PROVIDERS][],
  contentProviders: (typeof FAKE_CONTENT_PROVIDERS)[keyof typeof FAKE_CONTENT_PROVIDERS][]
) {
  // Only GETs are stubbed; other verbs fall through to the app.
  await page.route(
    "**/api/admin/web-search/search-providers",
    async (route) => {
      if (route.request().method() !== "GET") {
        await route.continue();
        return;
      }
      await route.fulfill({ status: 200, json: searchProviders });
    }
  );
  await page.route(
    "**/api/admin/web-search/content-providers",
    async (route) => {
      if (route.request().method() !== "GET") {
        await route.continue();
        return;
      }
      await route.fulfill({ status: 200, json: contentProviders });
    }
  );
}
// Covers the disconnect flow for both web-search provider families:
// search engines (Exa/Brave) and content crawlers (Firecrawl/Exa), plus
// the built-in Onyx crawler which must never expose a disconnect button.
test.describe("Web Search Provider Disconnect", () => {
  test.beforeEach(async ({ page }) => {
    // Fresh session per test: drop cookies, then authenticate as admin.
    await page.context().clearCookies();
    await loginAs(page, "admin");
  });
  test.describe("Search Engine Providers", () => {
    test("should disconnect a connected (non-active) search provider", async ({
      page,
    }) => {
      // Exa active, Brave connected-but-inactive (the disconnect target).
      const searchProviders = [
        { ...FAKE_SEARCH_PROVIDERS.exa },
        { ...FAKE_SEARCH_PROVIDERS.brave },
      ];
      await mockWebSearchApis(page, searchProviders, []);
      await page.goto(WEB_SEARCH_URL);
      await page.waitForSelector("text=Search Engine", { timeout: 20000 });
      const braveCard = findProviderCard(page, "Brave");
      await braveCard.waitFor({ state: "visible", timeout: 10000 });
      await expectElementScreenshot(mainContainer(page), {
        name: "web-search-disconnect-non-active-before",
      });
      const disconnectButton = braveCard.getByRole("button", {
        name: "Disconnect Brave",
      });
      await expect(disconnectButton).toBeVisible();
      await expect(disconnectButton).toBeEnabled();
      // Mock the DELETE to succeed
      await page.route(
        "**/api/admin/web-search/search-providers/2",
        async (route) => {
          if (route.request().method() === "DELETE") {
            await page.unroute("**/api/admin/web-search/search-providers");
            await page.route(
              "**/api/admin/web-search/search-providers",
              async (route) => {
                if (route.request().method() === "GET") {
                  await route.fulfill({
                    status: 200,
                    json: [{ ...FAKE_SEARCH_PROVIDERS.exa }],
                  });
                } else {
                  await route.continue();
                }
              }
            );
            await route.fulfill({ status: 200, json: {} });
          } else {
            await route.continue();
          }
        }
      );
      await disconnectButton.click();
      const confirmDialog = page.getByRole("dialog");
      await expect(confirmDialog).toBeVisible({ timeout: 5000 });
      await expect(confirmDialog).toContainText("Disconnect Brave");
      await expectElementScreenshot(confirmDialog, {
        name: "web-search-disconnect-non-active-modal",
      });
      const confirmButton = confirmDialog.getByRole("button", {
        name: "Disconnect",
      });
      await confirmButton.click();
      // Card reverts to the disconnected ("Connect") state after removal.
      await expect(
        braveCard.getByRole("button", { name: "Connect" })
      ).toBeVisible({ timeout: 10000 });
      await expectElementScreenshot(mainContainer(page), {
        name: "web-search-disconnect-non-active-after",
      });
    });
    test("should show replacement dropdown when disconnecting active search provider with alternatives", async ({
      page,
    }) => {
      // Exa is active, Brave is also configured
      const searchProviders = [
        { ...FAKE_SEARCH_PROVIDERS.exa },
        { ...FAKE_SEARCH_PROVIDERS.brave },
      ];
      await mockWebSearchApis(page, searchProviders, []);
      await page.goto(WEB_SEARCH_URL);
      await page.waitForSelector("text=Search Engine", { timeout: 20000 });
      const exaCard = findProviderCard(page, "Exa");
      await exaCard.waitFor({ state: "visible", timeout: 10000 });
      const disconnectButton = exaCard.getByRole("button", {
        name: "Disconnect Exa",
      });
      await expect(disconnectButton).toBeVisible();
      await expect(disconnectButton).toBeEnabled();
      await disconnectButton.click();
      const confirmDialog = page.getByRole("dialog");
      await expect(confirmDialog).toBeVisible({ timeout: 5000 });
      await expect(confirmDialog).toContainText("Disconnect Exa");
      // Should show replacement dropdown
      await expect(
        confirmDialog.getByText("Search history will be preserved")
      ).toBeVisible();
      // Disconnect button should be enabled because first replacement is auto-selected
      const confirmButton = confirmDialog.getByRole("button", {
        name: "Disconnect",
      });
      await expect(confirmButton).toBeEnabled();
      await expectElementScreenshot(confirmDialog, {
        name: "web-search-disconnect-active-with-alt-modal",
      });
    });
    test("should show connect message when disconnecting active search provider with no alternatives", async ({
      page,
    }) => {
      // Only Exa configured and active
      await mockWebSearchApis(page, [{ ...FAKE_SEARCH_PROVIDERS.exa }], []);
      await page.goto(WEB_SEARCH_URL);
      await page.waitForSelector("text=Search Engine", { timeout: 20000 });
      const exaCard = findProviderCard(page, "Exa");
      await exaCard.waitFor({ state: "visible", timeout: 10000 });
      const disconnectButton = exaCard.getByRole("button", {
        name: "Disconnect Exa",
      });
      await disconnectButton.click();
      const confirmDialog = page.getByRole("dialog");
      await expect(confirmDialog).toBeVisible({ timeout: 5000 });
      // Should show message about connecting another provider
      await expect(
        confirmDialog.getByText("Connect another provider")
      ).toBeVisible();
      // Disconnect button should be enabled
      const confirmButton = confirmDialog.getByRole("button", {
        name: "Disconnect",
      });
      await expect(confirmButton).toBeEnabled();
      await expectElementScreenshot(confirmDialog, {
        name: "web-search-disconnect-no-alt-modal",
      });
    });
    test("should not show disconnect button for unconfigured search provider", async ({
      page,
    }) => {
      // Only Exa configured — Brave renders in unconfigured state.
      await mockWebSearchApis(page, [{ ...FAKE_SEARCH_PROVIDERS.exa }], []);
      await page.goto(WEB_SEARCH_URL);
      await page.waitForSelector("text=Search Engine", { timeout: 20000 });
      const braveCard = findProviderCard(page, "Brave");
      await braveCard.waitFor({ state: "visible", timeout: 10000 });
      const disconnectButton = braveCard.getByRole("button", {
        name: "Disconnect Brave",
      });
      await expect(disconnectButton).not.toBeVisible();
      await expectElementScreenshot(mainContainer(page), {
        name: "web-search-disconnect-unconfigured",
      });
    });
  });
  test.describe("Web Crawler (Content) Providers", () => {
    test("should disconnect a connected (non-active) content provider", async ({
      page,
    }) => {
      // Firecrawl connected but not active, Exa is active
      const contentProviders = [
        { ...FAKE_CONTENT_PROVIDERS.firecrawl, is_active: false },
        { ...FAKE_CONTENT_PROVIDERS.exa, is_active: true },
      ];
      await mockWebSearchApis(page, [], contentProviders);
      await page.goto(WEB_SEARCH_URL);
      await page.waitForSelector("text=Web Crawler", { timeout: 20000 });
      const firecrawlCard = findProviderCard(page, "Firecrawl");
      await firecrawlCard.waitFor({ state: "visible", timeout: 10000 });
      const disconnectButton = firecrawlCard.getByRole("button", {
        name: "Disconnect Firecrawl",
      });
      await expect(disconnectButton).toBeVisible();
      await expect(disconnectButton).toBeEnabled();
      // Mock the DELETE to succeed
      await page.route(
        "**/api/admin/web-search/content-providers/10",
        async (route) => {
          if (route.request().method() === "DELETE") {
            await page.unroute("**/api/admin/web-search/content-providers");
            await page.route(
              "**/api/admin/web-search/content-providers",
              async (route) => {
                if (route.request().method() === "GET") {
                  await route.fulfill({
                    status: 200,
                    json: [{ ...FAKE_CONTENT_PROVIDERS.exa, is_active: true }],
                  });
                } else {
                  await route.continue();
                }
              }
            );
            await route.fulfill({ status: 200, json: {} });
          } else {
            await route.continue();
          }
        }
      );
      await disconnectButton.click();
      const confirmDialog = page.getByRole("dialog");
      await expect(confirmDialog).toBeVisible({ timeout: 5000 });
      await expect(confirmDialog).toContainText("Disconnect Firecrawl");
      await expectElementScreenshot(confirmDialog, {
        name: "web-search-disconnect-content-non-active-modal",
      });
      const confirmButton = confirmDialog.getByRole("button", {
        name: "Disconnect",
      });
      await confirmButton.click();
      // Card reverts to the disconnected ("Connect") state after removal.
      await expect(
        firecrawlCard.getByRole("button", { name: "Connect" })
      ).toBeVisible({ timeout: 10000 });
    });
    test("should show replacement dropdown when disconnecting active content provider with alternatives", async ({
      page,
    }) => {
      // Firecrawl is active, Exa is also configured
      const contentProviders = [
        { ...FAKE_CONTENT_PROVIDERS.firecrawl },
        { ...FAKE_CONTENT_PROVIDERS.exa },
      ];
      await mockWebSearchApis(page, [], contentProviders);
      await page.goto(WEB_SEARCH_URL);
      await page.waitForSelector("text=Web Crawler", { timeout: 20000 });
      const firecrawlCard = findProviderCard(page, "Firecrawl");
      await firecrawlCard.waitFor({ state: "visible", timeout: 10000 });
      const disconnectButton = firecrawlCard.getByRole("button", {
        name: "Disconnect Firecrawl",
      });
      await disconnectButton.click();
      const confirmDialog = page.getByRole("dialog");
      await expect(confirmDialog).toBeVisible({ timeout: 5000 });
      // Should show replacement dropdown
      await expect(
        confirmDialog.getByText("Search history will be preserved")
      ).toBeVisible();
      // Disconnect should be enabled because first replacement is auto-selected
      const confirmButton = confirmDialog.getByRole("button", {
        name: "Disconnect",
      });
      await expect(confirmButton).toBeEnabled();
      await expectElementScreenshot(confirmDialog, {
        name: "web-search-disconnect-content-active-with-alt-modal",
      });
    });
    test("should not show disconnect for Onyx Web Crawler (built-in)", async ({
      page,
    }) => {
      // No configured providers — only the built-in crawler is present.
      await mockWebSearchApis(page, [], []);
      await page.goto(WEB_SEARCH_URL);
      await page.waitForSelector("text=Web Crawler", { timeout: 20000 });
      const onyxCard = findProviderCard(page, "Onyx Web Crawler");
      await onyxCard.waitFor({ state: "visible", timeout: 10000 });
      const disconnectButton = onyxCard.getByRole("button", {
        name: "Disconnect Onyx Web Crawler",
      });
      await expect(disconnectButton).not.toBeVisible();
    });
  });
});