mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-27 10:32:41 +00:00
Compare commits
17 Commits
v3.1.0-clo
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf19d0df4f | ||
|
|
86a6a4c134 | ||
|
|
146b5449d2 | ||
|
|
b66991b5c5 | ||
|
|
9cb76dc027 | ||
|
|
f66891d19e | ||
|
|
c07c952ad5 | ||
|
|
be7f40a28a | ||
|
|
26f941b5da | ||
|
|
b9e84c42a8 | ||
|
|
0a1df52c2f | ||
|
|
306b0d452f | ||
|
|
5fdb34ba8e | ||
|
|
2d066631e3 | ||
|
|
5c84f6c61b | ||
|
|
899179d4b6 | ||
|
|
80d6bafc74 |
@@ -24,6 +24,16 @@ When hardcoding a boolean variable to a constant value, remove the variable enti
|
||||
|
||||
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
|
||||
|
||||
## Nginx Routing — New Backend Routes
|
||||
|
||||
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
|
||||
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
|
||||
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
|
||||
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
|
||||
- `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt)
|
||||
|
||||
Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`.
|
||||
|
||||
## Full vs Lite Deployments
|
||||
|
||||
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""remove voice_provider deleted column
|
||||
|
||||
Revision ID: 1d78c0ca7853
|
||||
Revises: a3f8b2c1d4e5
|
||||
Create Date: 2026-03-26 11:30:53.883127
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1d78c0ca7853"
|
||||
down_revision = "a3f8b2c1d4e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Hard-delete any soft-deleted rows before dropping the column
|
||||
op.execute("DELETE FROM voice_provider WHERE deleted = true")
|
||||
op.drop_column("voice_provider", "deleted")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"voice_provider",
|
||||
sa.Column(
|
||||
"deleted",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
@@ -473,6 +473,8 @@ def connector_permission_sync_generator_task(
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call
|
||||
@@ -105,9 +106,11 @@ def _get_slack_document_access(
|
||||
slack_connector: SlackConnector,
|
||||
channel_permissions: dict[str, ExternalAccess], # noqa: ARG001
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
indexing_start: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
callback=callback
|
||||
callback=callback,
|
||||
start=indexing_start,
|
||||
)
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
@@ -180,9 +183,15 @@ def slack_doc_sync(
|
||||
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.set_credentials_provider(provider)
|
||||
indexing_start_ts: SecondsSinceUnixEpoch | None = (
|
||||
cc_pair.connector.indexing_start.timestamp()
|
||||
if cc_pair.connector.indexing_start is not None
|
||||
else None
|
||||
)
|
||||
|
||||
yield from _get_slack_document_access(
|
||||
slack_connector,
|
||||
slack_connector=slack_connector,
|
||||
channel_permissions=channel_permissions,
|
||||
callback=callback,
|
||||
indexing_start=indexing_start_ts,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.models import NodeExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -40,10 +41,19 @@ def generic_doc_sync(
|
||||
|
||||
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
|
||||
|
||||
indexing_start: SecondsSinceUnixEpoch | None = (
|
||||
cc_pair.connector.indexing_start.timestamp()
|
||||
if cc_pair.connector.indexing_start is not None
|
||||
else None
|
||||
)
|
||||
|
||||
newly_fetched_doc_ids: set[str] = set()
|
||||
|
||||
logger.info(f"Fetching all slim documents from {doc_source}")
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=indexing_start,
|
||||
callback=callback,
|
||||
):
|
||||
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
|
||||
|
||||
if callback:
|
||||
|
||||
@@ -44,6 +44,31 @@ SEND_USER_METADATA_TO_LLM_PROVIDER = (
|
||||
# User Facing Features Configs
|
||||
#####
|
||||
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
|
||||
|
||||
# Hard ceiling for the admin-configurable file upload size (in MB).
|
||||
# Self-hosted customers can raise or lower this via the environment variable.
|
||||
_raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250"))
|
||||
if _raw_max_upload_size_mb < 0:
|
||||
logger.warning(
|
||||
"MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250",
|
||||
_raw_max_upload_size_mb,
|
||||
)
|
||||
_raw_max_upload_size_mb = 250
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb
|
||||
|
||||
# Default fallback for the per-user file upload size limit (in MB) when no
|
||||
# admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at
|
||||
# runtime so this never silently exceeds the hard ceiling.
|
||||
_raw_default_upload_size_mb = int(
|
||||
os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100")
|
||||
)
|
||||
if _raw_default_upload_size_mb < 0:
|
||||
logger.warning(
|
||||
"DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100",
|
||||
_raw_default_upload_size_mb,
|
||||
)
|
||||
_raw_default_upload_size_mb = 100
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb
|
||||
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
|
||||
) # 1 day
|
||||
@@ -61,17 +86,6 @@ CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -890,8 +890,8 @@ class ConfluenceConnector(
|
||||
|
||||
def _retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
@@ -915,8 +915,8 @@ class ConfluenceConnector(
|
||||
self.confluence_client, doc_id, restrictions, ancestors
|
||||
) or space_level_access_info.get(page_space_key)
|
||||
|
||||
# Query pages
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
# Query pages (with optional time filtering for indexing_start)
|
||||
page_query = self._construct_page_cql_query(start, end)
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
@@ -950,7 +950,9 @@ class ConfluenceConnector(
|
||||
|
||||
# Query attachments for each page
|
||||
page_hierarchy_node_yielded = False
|
||||
attachment_query = self._construct_attachment_query(_get_page_id(page))
|
||||
attachment_query = self._construct_attachment_query(
|
||||
_get_page_id(page), start, end
|
||||
)
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
|
||||
@@ -10,6 +10,7 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
from jira.resources import Issue
|
||||
@@ -239,29 +240,53 @@ def enhanced_search_ids(
|
||||
)
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO: move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
def _bulk_fetch_request(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
|
||||
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
|
||||
|
||||
# Prepare the payload according to Jira API v3 specification
|
||||
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
|
||||
|
||||
# Only restrict fields if specified, might want to explicitly do this in the future
|
||||
# to avoid reading unnecessary data
|
||||
payload["fields"] = fields.split(",") if fields else ["*all"]
|
||||
|
||||
resp = jira_client._session.post(bulk_fetch_path, json=payload)
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
|
||||
f"be decoded as JSON (response too large or truncated)."
|
||||
)
|
||||
raise
|
||||
|
||||
mid = len(issue_ids) // 2
|
||||
logger.warning(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise e
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
for issue in response["issues"]
|
||||
for issue in raw_issues
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1765,7 +1765,11 @@ class SharepointConnector(
|
||||
checkpoint.current_drive_delta_next_link = None
|
||||
checkpoint.seen_document_ids.clear()
|
||||
|
||||
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
|
||||
def _fetch_slim_documents_from_sharepoint(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self._filter_excluded_sites(
|
||||
self.site_descriptors or self.fetch_sites()
|
||||
)
|
||||
@@ -1786,7 +1790,9 @@ class SharepointConnector(
|
||||
# Process site documents if flag is True
|
||||
if self.include_site_documents:
|
||||
for driveitem, drive_name, drive_web_url in self._fetch_driveitems(
|
||||
site_descriptor=site_descriptor
|
||||
site_descriptor=site_descriptor,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if self._is_driveitem_excluded(driveitem):
|
||||
logger.debug(f"Excluding by path denylist: {driveitem.web_url}")
|
||||
@@ -1841,7 +1847,9 @@ class SharepointConnector(
|
||||
|
||||
# Process site pages if flag is True
|
||||
if self.include_site_pages:
|
||||
site_pages = self._fetch_site_pages(site_descriptor)
|
||||
site_pages = self._fetch_site_pages(
|
||||
site_descriptor, start=start, end=end
|
||||
)
|
||||
for site_page in site_pages:
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
@@ -2565,12 +2573,22 @@ class SharepointConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
|
||||
yield from self._fetch_slim_documents_from_sharepoint()
|
||||
start_dt = (
|
||||
datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
if start is not None
|
||||
else None
|
||||
)
|
||||
end_dt = (
|
||||
datetime.fromtimestamp(end, tz=timezone.utc) if end is not None else None
|
||||
)
|
||||
yield from self._fetch_slim_documents_from_sharepoint(
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -516,6 +516,8 @@ def _get_all_doc_ids(
|
||||
] = default_msg_filter,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
workspace_url: str | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
@@ -546,6 +548,8 @@ def _get_all_doc_ids(
|
||||
client=client,
|
||||
channel=channel,
|
||||
callback=callback,
|
||||
oldest=str(start) if start else None, # 0.0 -> None intentionally
|
||||
latest=str(end) if end is not None else None,
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
@@ -847,8 +851,8 @@ class SlackConnector(
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.client is None:
|
||||
@@ -861,6 +865,8 @@ class SlackConnector(
|
||||
msg_filter_func=self.msg_filter_func,
|
||||
callback=callback,
|
||||
workspace_url=self._workspace_url,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _load_from_checkpoint(
|
||||
|
||||
@@ -3135,8 +3135,6 @@ class VoiceProvider(Base):
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
@@ -17,39 +17,30 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
.order_by(VoiceProvider.name)
|
||||
).all()
|
||||
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int, include_deleted: bool = False
|
||||
db_session: Session, provider_id: int
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(VoiceProvider.deleted.is_(False))
|
||||
return db_session.scalar(stmt)
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_stt.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_tts.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
|
||||
)
|
||||
|
||||
|
||||
@@ -58,9 +49,7 @@ def fetch_voice_provider_by_type(
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.provider_type == provider_type)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
|
||||
)
|
||||
|
||||
|
||||
@@ -119,10 +108,10 @@ def upsert_voice_provider(
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Soft-delete a voice provider by ID."""
|
||||
"""Delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
provider.deleted = True
|
||||
db_session.delete(provider)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
|
||||
out against a nonexistent Vespa/OpenSearch instance.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from onyx.context.search.models import IndexFilters
|
||||
@@ -66,7 +67,7 @@ class DisabledDocumentIndex(DocumentIndex):
|
||||
# ------------------------------------------------------------------
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
index_batch_params: IndexBatchParams, # noqa: ARG002
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
@@ -206,7 +207,7 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -226,8 +227,8 @@ class Indexable(abc.ABC):
|
||||
it is done automatically outside of this code.
|
||||
|
||||
Parameters:
|
||||
- chunks: Document chunks with all of the information needed for indexing to the document
|
||||
index.
|
||||
- chunks: Document chunks with all of the information needed for
|
||||
indexing to the document index.
|
||||
- tenant_id: The tenant id of the user whose chunks are being indexed
|
||||
- large_chunks_enabled: Whether large chunks are enabled
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -209,10 +210,10 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
This is often a batch operation including chunks from multiple
|
||||
documents.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -351,7 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -647,10 +647,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata, # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
Groups chunks by document ID and for each document, deletes existing
|
||||
chunks and indexes the new chunks in bulk.
|
||||
@@ -673,29 +673,34 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document is newly indexed or had already existed and was just
|
||||
updated.
|
||||
"""
|
||||
# Group chunks by document ID.
|
||||
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
|
||||
list
|
||||
total_chunks = sum(
|
||||
cc.new_chunk_cnt
|
||||
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
|
||||
)
|
||||
for chunk in chunks:
|
||||
doc_id_to_chunks[chunk.source_document.id].append(chunk)
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
|
||||
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
|
||||
f"documents for index {self._index_name}."
|
||||
)
|
||||
|
||||
document_indexing_results: list[DocumentInsertionRecord] = []
|
||||
# Try to index per-document.
|
||||
for _, chunks in doc_id_to_chunks.items():
|
||||
deleted_doc_ids: set[str] = set()
|
||||
# Buffer chunks per document as they arrive from the iterable.
|
||||
# When the document ID changes flush the buffered chunks.
|
||||
current_doc_id: str | None = None
|
||||
current_chunks: list[DocMetadataAwareIndexChunk] = []
|
||||
|
||||
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
|
||||
assert len(doc_chunks) > 0, "doc_chunks is empty"
|
||||
|
||||
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
|
||||
# Do this before deleting existing chunks to reduce the amount of
|
||||
# time the document index has no content for a given document, and
|
||||
# to reduce the chance of entering a state where we delete chunks,
|
||||
# then some error happens, and never successfully index new chunks.
|
||||
# Since we are doing this in batches, an error occurring midway
|
||||
# can result in a state where chunks are deleted and not all the
|
||||
# new chunks have been indexed.
|
||||
chunk_batch: list[DocumentChunk] = [
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk)
|
||||
for chunk in doc_chunks
|
||||
]
|
||||
onyx_document: Document = chunks[0].source_document
|
||||
onyx_document: Document = doc_chunks[0].source_document
|
||||
# First delete the doc's chunks from the index. This is so that
|
||||
# there are no dangling chunks in the index, in the event that the
|
||||
# new document's content contains fewer chunks than the previous
|
||||
@@ -704,22 +709,40 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# if the chunk count has actually decreased. This assumes that
|
||||
# overlapping chunks are perfectly overwritten. If we can't
|
||||
# guarantee that then we need the code as-is.
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed.
|
||||
document_insertion_record = DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
if onyx_document.id not in deleted_doc_ids:
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
deleted_doc_ids.add(onyx_document.id)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed. We record the result before bulk_index_documents
|
||||
# runs. If indexing raises, this entire result list is discarded
|
||||
# by the caller's retry logic, so early recording is safe.
|
||||
document_indexing_results.append(
|
||||
DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
document_indexing_results.append(document_insertion_record)
|
||||
|
||||
for chunk in chunks:
|
||||
doc_id = chunk.source_document.id
|
||||
if doc_id != current_doc_id:
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
current_doc_id = doc_id
|
||||
current_chunks = [chunk]
|
||||
else:
|
||||
current_chunks.append(chunk)
|
||||
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
|
||||
return document_indexing_results
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
import time
|
||||
import urllib
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -461,7 +462,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@@ -318,7 +320,7 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
|
||||
@@ -338,22 +340,31 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
# Vespa has restrictions on valid characters, yet document IDs come from
|
||||
# external w.r.t. this class. We need to sanitize them.
|
||||
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
|
||||
clean_chunk_id_copy(chunk) for chunk in chunks
|
||||
]
|
||||
assert len(cleaned_chunks) == len(
|
||||
chunks
|
||||
), "Bug: Cleaned chunks and input chunks have different lengths."
|
||||
#
|
||||
# Instead of materializing all cleaned chunks upfront, we stream them
|
||||
# through a generator that cleans IDs and builds the original-ID mapping
|
||||
# incrementally as chunks flow into Vespa.
|
||||
def _clean_and_track(
|
||||
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
|
||||
id_map: dict[str, str],
|
||||
seen_ids: set[str],
|
||||
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
|
||||
"""Cleans chunk IDs and builds the original-ID mapping
|
||||
incrementally as chunks flow through, avoiding a separate
|
||||
materialization pass."""
|
||||
for chunk in chunks_iter:
|
||||
original_id = chunk.source_document.id
|
||||
cleaned = clean_chunk_id_copy(chunk)
|
||||
cleaned_id = cleaned.source_document.id
|
||||
# Needed so the final DocumentInsertionRecord returned can have
|
||||
# the original document ID. cleaned_chunks might not contain IDs
|
||||
# exactly as callers supplied them.
|
||||
id_map[cleaned_id] = original_id
|
||||
seen_ids.add(cleaned_id)
|
||||
yield cleaned
|
||||
|
||||
# Needed so the final DocumentInsertionRecord returned can have the
|
||||
# original document ID. cleaned_chunks might not contain IDs exactly as
|
||||
# callers supplied them.
|
||||
new_document_id_to_original_document_id: dict[str, str] = dict()
|
||||
for i, cleaned_chunk in enumerate(cleaned_chunks):
|
||||
old_chunk = chunks[i]
|
||||
new_document_id_to_original_document_id[
|
||||
cleaned_chunk.source_document.id
|
||||
] = old_chunk.source_document.id
|
||||
new_document_id_to_original_document_id: dict[str, str] = {}
|
||||
all_cleaned_doc_ids: set[str] = set()
|
||||
|
||||
existing_docs: set[str] = set()
|
||||
|
||||
@@ -409,7 +420,13 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
# Insert new Vespa documents.
|
||||
# Insert new Vespa documents, streaming through the cleaning
|
||||
# pipeline so chunks are never fully materialized.
|
||||
cleaned_chunks = _clean_and_track(
|
||||
chunks,
|
||||
new_document_id_to_original_document_id,
|
||||
all_cleaned_doc_ids,
|
||||
)
|
||||
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
|
||||
batch_index_vespa_chunks(
|
||||
chunks=chunk_batch,
|
||||
@@ -419,10 +436,6 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
all_cleaned_doc_ids: set[str] = {
|
||||
chunk.source_document.id for chunk in cleaned_chunks
|
||||
}
|
||||
|
||||
return [
|
||||
DocumentInsertionRecord(
|
||||
document_id=new_document_id_to_original_document_id[cleaned_doc_id],
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -173,8 +174,10 @@ class UserFileIndexingAdapter:
|
||||
[chunk.content for chunk in user_file_chunks]
|
||||
)
|
||||
user_file_id_to_raw_text[str(user_file_id)] = combined_content
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
|
||||
token_count: int = (
|
||||
count_tokens(combined_content, llm_tokenizer)
|
||||
if llm_tokenizer
|
||||
else 0
|
||||
)
|
||||
user_file_id_to_token_count[str(user_file_id)] = token_count
|
||||
else:
|
||||
|
||||
@@ -175,6 +175,32 @@ def get_tokenizer(
|
||||
return _check_tokenizer_cache(provider_type, model_name)
|
||||
|
||||
|
||||
# Max characters per encode() call.
|
||||
_ENCODE_CHUNK_SIZE = 500_000
|
||||
|
||||
|
||||
def count_tokens(
|
||||
text: str,
|
||||
tokenizer: BaseTokenizer,
|
||||
token_limit: int | None = None,
|
||||
) -> int:
|
||||
"""Count tokens, chunking the input to avoid tiktoken stack overflow.
|
||||
|
||||
If token_limit is provided and the text is large enough to require
|
||||
multiple chunks (> 500k chars), stops early once the count exceeds it.
|
||||
When early-exiting, the returned value exceeds token_limit but may be
|
||||
less than the true full token count.
|
||||
"""
|
||||
if len(text) <= _ENCODE_CHUNK_SIZE:
|
||||
return len(tokenizer.encode(text))
|
||||
total = 0
|
||||
for start in range(0, len(text), _ENCODE_CHUNK_SIZE):
|
||||
total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE]))
|
||||
if token_limit is not None and total > token_limit:
|
||||
return total # Already over — skip remaining chunks
|
||||
return total
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
|
||||
@@ -44,11 +44,12 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -141,19 +142,11 @@ def _validate_endpoint(
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
# Any timeout (connect, read, or write) means the configured timeout_seconds
|
||||
# is too low for this endpoint. Report as timeout so the UI directs the user
|
||||
# to increase the timeout setting.
|
||||
logger.warning(
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
"Hook endpoint validation: timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
|
||||
@@ -9,20 +9,15 @@ from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import SKIP_USERFILE_THRESHOLD
|
||||
from shared_configs.configs import SKIP_USERFILE_THRESHOLD_TENANT_LIST
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -161,8 +156,8 @@ def categorize_uploaded_files(
|
||||
document formats (.pdf, .docx, …) and falls back to a text-detection
|
||||
heuristic for unknown extensions (.py, .js, .rs, …).
|
||||
- Uses default tokenizer to compute token length.
|
||||
- If token length > threshold, reject file (unless threshold skip is enabled).
|
||||
- If text cannot be extracted, reject file.
|
||||
- If token length exceeds the admin-configured threshold, reject file.
|
||||
- If extension unsupported or text cannot be extracted, reject file.
|
||||
- Otherwise marked as acceptable.
|
||||
"""
|
||||
|
||||
@@ -173,36 +168,33 @@ def categorize_uploaded_files(
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
# Check global skip flag (works for both single-tenant and multi-tenant)
|
||||
if SKIP_USERFILE_THRESHOLD:
|
||||
skip_threshold = True
|
||||
logger.info("Skipping userfile threshold check (global setting)")
|
||||
# Check tenant-specific skip list (only applicable in multi-tenant)
|
||||
elif MULTI_TENANT and SKIP_USERFILE_THRESHOLD_TENANT_LIST:
|
||||
try:
|
||||
current_tenant_id = get_current_tenant_id()
|
||||
skip_threshold = current_tenant_id in SKIP_USERFILE_THRESHOLD_TENANT_LIST
|
||||
if skip_threshold:
|
||||
logger.info(
|
||||
f"Skipping userfile threshold check for tenant: {current_tenant_id}"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to get current tenant ID: {str(e)}")
|
||||
# Derive limits from admin-configurable settings.
|
||||
# For upload size: load_settings() resolves 0/None to a positive default.
|
||||
# For token threshold: 0 means "no limit" (converted to None below).
|
||||
settings = load_settings()
|
||||
max_upload_size_mb = (
|
||||
settings.user_file_max_upload_size_mb
|
||||
) # always positive after load_settings()
|
||||
max_upload_size_bytes = (
|
||||
max_upload_size_mb * 1024 * 1024 if max_upload_size_mb else None
|
||||
)
|
||||
token_threshold_k = settings.file_token_count_threshold_k
|
||||
token_threshold = (
|
||||
token_threshold_k * 1000 if token_threshold_k else None
|
||||
) # 0 → None = no limit
|
||||
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
|
||||
# Size limit is a hard safety cap and is enforced even when token
|
||||
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
|
||||
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
|
||||
# Size limit is a hard safety cap.
|
||||
if max_upload_size_bytes is not None and is_upload_too_large(
|
||||
upload, max_upload_size_bytes
|
||||
):
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
|
||||
reason=f"Exceeds {max_upload_size_mb} MB file size limit",
|
||||
)
|
||||
)
|
||||
continue
|
||||
@@ -224,11 +216,11 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
|
||||
if token_threshold is not None and token_count > token_threshold:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
|
||||
reason=f"Exceeds {token_threshold_k}K token limit",
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -269,12 +261,14 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
token_count = len(tokenizer.encode(text_content))
|
||||
if not skip_threshold and token_count > FILE_TOKEN_COUNT_THRESHOLD:
|
||||
token_count = count_tokens(
|
||||
text_content, tokenizer, token_limit=token_threshold
|
||||
)
|
||||
if token_threshold is not None and token_count > token_threshold:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {FILE_TOKEN_COUNT_THRESHOLD} token limit",
|
||||
reason=f"Exceeds {token_threshold_k}K token limit",
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -1524,6 +1524,7 @@ def get_bifrost_available_models(
|
||||
display_name=model_name,
|
||||
max_input_tokens=model.get("context_length"),
|
||||
supports_image_input=infer_vision_support(model_id),
|
||||
supports_reasoning=is_reasoning_model(model_id, model_name),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -463,3 +463,4 @@ class BifrostFinalModelResponse(BaseModel):
|
||||
display_name: str # Human-readable name from Bifrost API
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
@@ -9,7 +9,9 @@ from onyx import __version__ as onyx_version
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import is_user_admin
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
@@ -17,10 +19,16 @@ from onyx.db.models import User
|
||||
from onyx.db.notification import dismiss_all_notifications
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.db.notification import update_notification_last_shown
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.features.build.utils import is_onyx_craft_enabled
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Notification
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.models import UserSettings
|
||||
@@ -41,6 +49,15 @@ basic_router = APIRouter(prefix="/settings")
|
||||
def admin_put_settings(
|
||||
settings: Settings, _: User = Depends(current_admin_user)
|
||||
) -> None:
|
||||
if (
|
||||
settings.user_file_max_upload_size_mb is not None
|
||||
and settings.user_file_max_upload_size_mb > 0
|
||||
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"File upload size limit cannot exceed {MAX_ALLOWED_UPLOAD_SIZE_MB} MB",
|
||||
)
|
||||
store_settings(settings)
|
||||
|
||||
|
||||
@@ -83,6 +100,16 @@ def fetch_settings(
|
||||
vector_db_enabled=not DISABLE_VECTOR_DB,
|
||||
hooks_enabled=HOOKS_AVAILABLE,
|
||||
version=onyx_version,
|
||||
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
default_user_file_max_upload_size_mb=min(
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
),
|
||||
default_file_token_count_threshold_k=(
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,12 +2,19 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.models import Notification as NotificationDBModel
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB = 200
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB = 10000
|
||||
|
||||
|
||||
class PageType(str, Enum):
|
||||
CHAT = "chat"
|
||||
@@ -78,7 +85,12 @@ class Settings(BaseModel):
|
||||
|
||||
# User Knowledge settings
|
||||
user_knowledge_enabled: bool | None = True
|
||||
user_file_max_upload_size_mb: int | None = None
|
||||
user_file_max_upload_size_mb: int | None = Field(
|
||||
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
|
||||
)
|
||||
file_token_count_threshold_k: int | None = Field(
|
||||
default=None, ge=0 # thousands of tokens; None = context-aware default
|
||||
)
|
||||
|
||||
# Connector settings
|
||||
show_extra_connectors: bool | None = True
|
||||
@@ -108,3 +120,14 @@ class UserSettings(Settings):
|
||||
hooks_enabled: bool = False
|
||||
# Application version, read from the ONYX_VERSION env var at startup.
|
||||
version: str | None = None
|
||||
# Hard ceiling for user_file_max_upload_size_mb, derived from env var.
|
||||
max_allowed_upload_size_mb: int = MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
# Factory defaults so the frontend can show a "restore default" button.
|
||||
default_user_file_max_upload_size_mb: int = DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
default_file_token_count_threshold_k: int = Field(
|
||||
default_factory=lambda: (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -51,9 +57,36 @@ def load_settings() -> Settings:
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
|
||||
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
|
||||
# Resolve context-aware defaults for token threshold.
|
||||
# None = admin hasn't set a value yet → use context-aware default.
|
||||
# 0 = admin explicitly chose "no limit" → preserve as-is.
|
||||
if settings.file_token_count_threshold_k is None:
|
||||
settings.file_token_count_threshold_k = (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
if DISABLE_VECTOR_DB
|
||||
else DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
|
||||
# Upload size: 0 and None are treated as "unset" (not "no limit") →
|
||||
# fall back to min(configured default, hard ceiling).
|
||||
if not settings.user_file_max_upload_size_mb:
|
||||
settings.user_file_max_upload_size_mb = min(
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB,
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
)
|
||||
|
||||
# Clamp to env ceiling so stale KV values are capped even if the
|
||||
# operator lowered MAX_ALLOWED_UPLOAD_SIZE_MB after a higher value
|
||||
# was already saved (api.py only guards new writes).
|
||||
if (
|
||||
settings.user_file_max_upload_size_mb > 0
|
||||
and settings.user_file_max_upload_size_mb > MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
):
|
||||
settings.user_file_max_upload_size_mb = MAX_ALLOWED_UPLOAD_SIZE_MB
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
|
||||
@@ -191,25 +191,6 @@ IGNORED_SYNCING_TENANT_LIST = (
|
||||
else None
|
||||
)
|
||||
|
||||
# Global flag to skip userfile threshold for all users/tenants
|
||||
SKIP_USERFILE_THRESHOLD = (
|
||||
os.environ.get("SKIP_USERFILE_THRESHOLD", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Comma-separated list of specific tenant IDs to skip threshold (multi-tenant only)
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_IDS = os.environ.get(
|
||||
"SKIP_USERFILE_THRESHOLD_TENANT_IDS"
|
||||
)
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_LIST = (
|
||||
[
|
||||
tenant.strip()
|
||||
for tenant in SKIP_USERFILE_THRESHOLD_TENANT_IDS.split(",")
|
||||
if tenant.strip()
|
||||
]
|
||||
if SKIP_USERFILE_THRESHOLD_TENANT_IDS
|
||||
else None
|
||||
)
|
||||
|
||||
ENVIRONMENT = os.environ.get("ENVIRONMENT") or "not_explicitly_set"
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -17,6 +19,10 @@ PRIVATE_CHANNEL_USERS = [
|
||||
"test_user_2@onyx-test.com",
|
||||
]
|
||||
|
||||
# Predates any test workspace messages, so the result set should match
|
||||
# the "no start time" case while exercising the oldest= parameter.
|
||||
OLDEST_TS_2016 = datetime(2016, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
@@ -105,15 +111,17 @@ def test_load_from_checkpoint_access__private_channel(
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.parametrize("start_ts", [None, OLDEST_TS_2016])
|
||||
def test_slim_documents_access__public_channel(
|
||||
slack_connector: SlackConnector,
|
||||
start_ts: float | None,
|
||||
) -> None:
|
||||
"""Test that retrieve_all_slim_docs_perm_sync returns correct access information for slim documents."""
|
||||
if not slack_connector.client:
|
||||
raise RuntimeError("Web client must be defined")
|
||||
|
||||
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=0.0,
|
||||
start=start_ts,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
@@ -149,7 +157,7 @@ def test_slim_documents_access__private_channel(
|
||||
raise RuntimeError("Web client must be defined")
|
||||
|
||||
slim_docs_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=0.0,
|
||||
start=None,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
|
||||
@@ -103,6 +103,11 @@ _EXPECTED_CONFLUENCE_GROUPS = [
|
||||
user_emails={"oauth@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="no yuhong allowed",
|
||||
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ These tests assume Vespa and OpenSearch are running.
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@@ -21,6 +22,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
|
||||
from tests.external_dependency_unit.document_index.conftest import make_chunk
|
||||
@@ -201,3 +203,25 @@ class TestDocumentIndexNew:
|
||||
assert len(result_map) == 2
|
||||
assert result_map[existing_doc] is True
|
||||
assert result_map[new_doc] is False
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndexNew],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
|
||||
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk(doc_id, chunk_id=i)
|
||||
|
||||
results = document_index.index(
|
||||
chunks=chunk_gen(), indexing_metadata=metadata
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == doc_id
|
||||
assert results[0].already_existed is False
|
||||
|
||||
@@ -5,6 +5,7 @@ These tests assume Vespa and OpenSearch are running.
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -166,3 +167,29 @@ class TestDocumentIndexOld:
|
||||
batch_retrieval=True,
|
||||
)
|
||||
assert len(inference_chunks) == 0
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndex],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk("test_doc_gen", chunk_id=i)
|
||||
|
||||
index_batch_params = IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
|
||||
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
|
||||
tenant_id=get_current_tenant_id(),
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
results = document_index.index(chunk_gen(), index_batch_params)
|
||||
|
||||
assert len(results) == 1
|
||||
record = results.pop()
|
||||
assert record.document_id == "test_doc_gen"
|
||||
assert record.already_existed is False
|
||||
|
||||
147
backend/tests/unit/onyx/connectors/jira/test_jira_bulk_fetch.py
Normal file
147
backend/tests/unit/onyx/connectors/jira/test_jira_bulk_fetch.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.connectors.jira.connector import bulk_fetch_issues
|
||||
|
||||
|
||||
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": issue_id,
|
||||
"key": f"TEST-{issue_id}",
|
||||
"fields": {"summary": f"Issue {issue_id}"},
|
||||
}
|
||||
|
||||
|
||||
def _mock_jira_client() -> MagicMock:
|
||||
mock = MagicMock(spec=JIRA)
|
||||
mock._options = {"server": "https://jira.example.com"}
|
||||
mock._session = MagicMock()
|
||||
mock._get_url = MagicMock(
|
||||
return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
def test_bulk_fetch_success() -> None:
|
||||
"""Happy path: all issues fetched in one request."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1"), _make_raw_issue("2"), _make_raw_issue("3")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3"])
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(r, Issue) for r in result)
|
||||
client._session.post.assert_called_once()
|
||||
|
||||
|
||||
def test_bulk_fetch_splits_on_json_error() -> None:
|
||||
"""When the full batch fails with JSONDecodeError, sub-batches succeed."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
call_count = 0
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if len(ids) > 2:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 2294125
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3", "4"])
|
||||
assert len(result) == 4
|
||||
returned_ids = {r.raw["id"] for r in result}
|
||||
assert returned_ids == {"1", "2", "3", "4"}
|
||||
assert call_count > 1
|
||||
|
||||
|
||||
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
|
||||
"""A single issue that always fails JSON decode raises after splitting."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if "bad" in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 100
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "bad", "2"])
|
||||
|
||||
|
||||
def test_bulk_fetch_non_json_error_propagates() -> None:
|
||||
"""Non-JSONDecodeError exceptions still propagate."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = ValueError("something else broke")
|
||||
client._session.post.return_value = resp
|
||||
|
||||
try:
|
||||
bulk_fetch_issues(client, ["1"])
|
||||
assert False, "Expected ValueError to propagate"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def test_bulk_fetch_with_fields() -> None:
|
||||
"""Fields parameter is forwarded correctly."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
bulk_fetch_issues(client, ["1"], fields="summary,description")
|
||||
|
||||
call_payload = client._session.post.call_args[1]["json"]
|
||||
assert call_payload["fields"] == ["summary", "description"]
|
||||
|
||||
|
||||
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
|
||||
"""With a 6-issue batch where one is bad, recursion isolates it and raises."""
|
||||
client = _mock_jira_client()
|
||||
bad_id = "BAD"
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if bad_id in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"truncated", "doc", 999
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -31,6 +33,7 @@ def mock_jira_cc_pair(
|
||||
"jira_base_url": jira_base_url,
|
||||
"project_key": project_key,
|
||||
}
|
||||
mock_cc_pair.connector.indexing_start = None
|
||||
|
||||
return mock_cc_pair
|
||||
|
||||
@@ -65,3 +68,75 @@ def test_jira_permission_sync(
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
):
|
||||
print(doc)
|
||||
|
||||
|
||||
def test_jira_doc_sync_passes_indexing_start(
|
||||
jira_connector: JiraConnector,
|
||||
mock_jira_cc_pair: MagicMock,
|
||||
mock_fetch_all_existing_docs_fn: MagicMock,
|
||||
mock_fetch_all_existing_docs_ids_fn: MagicMock,
|
||||
) -> None:
|
||||
"""Verify that generic_doc_sync derives indexing_start from cc_pair
|
||||
and forwards it to retrieve_all_slim_docs_perm_sync."""
|
||||
indexing_start_dt = datetime(2025, 6, 1, tzinfo=timezone.utc)
|
||||
mock_jira_cc_pair.connector.indexing_start = indexing_start_dt
|
||||
|
||||
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
|
||||
mock_build_client.return_value = jira_connector._jira_client
|
||||
assert jira_connector._jira_client is not None
|
||||
jira_connector._jira_client._options = MagicMock()
|
||||
jira_connector._jira_client._options.return_value = {
|
||||
"rest_api_version": JIRA_SERVER_API_VERSION
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
type(jira_connector),
|
||||
"retrieve_all_slim_docs_perm_sync",
|
||||
return_value=iter([]),
|
||||
) as mock_retrieve:
|
||||
list(
|
||||
jira_doc_sync(
|
||||
cc_pair=mock_jira_cc_pair,
|
||||
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
)
|
||||
)
|
||||
|
||||
mock_retrieve.assert_called_once()
|
||||
call_kwargs = mock_retrieve.call_args
|
||||
assert call_kwargs.kwargs["start"] == indexing_start_dt.timestamp()
|
||||
|
||||
|
||||
def test_jira_doc_sync_passes_none_when_no_indexing_start(
|
||||
jira_connector: JiraConnector,
|
||||
mock_jira_cc_pair: MagicMock,
|
||||
mock_fetch_all_existing_docs_fn: MagicMock,
|
||||
mock_fetch_all_existing_docs_ids_fn: MagicMock,
|
||||
) -> None:
|
||||
"""Verify that indexing_start is None when the connector has no indexing_start set."""
|
||||
mock_jira_cc_pair.connector.indexing_start = None
|
||||
|
||||
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
|
||||
mock_build_client.return_value = jira_connector._jira_client
|
||||
assert jira_connector._jira_client is not None
|
||||
jira_connector._jira_client._options = MagicMock()
|
||||
jira_connector._jira_client._options.return_value = {
|
||||
"rest_api_version": JIRA_SERVER_API_VERSION
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
type(jira_connector),
|
||||
"retrieve_all_slim_docs_perm_sync",
|
||||
return_value=iter([]),
|
||||
) as mock_retrieve:
|
||||
list(
|
||||
jira_doc_sync(
|
||||
cc_pair=mock_jira_cc_pair,
|
||||
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
)
|
||||
)
|
||||
|
||||
mock_retrieve.assert_called_once()
|
||||
call_kwargs = mock_retrieve.call_args
|
||||
assert call_kwargs.kwargs["start"] is None
|
||||
|
||||
@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
|
||||
class TestDeleteVoiceProvider:
|
||||
"""Tests for delete_voice_provider."""
|
||||
|
||||
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
delete_voice_provider(mock_db_session, 1)
|
||||
|
||||
assert provider.deleted is True
|
||||
mock_db_session.delete.assert_called_once_with(provider)
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_provider_not_found(
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Covers:
|
||||
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
|
||||
- _validate_endpoint: httpx exception → HookValidateStatus mapping
|
||||
ConnectTimeout → cannot_connect (TCP handshake never completed)
|
||||
ConnectTimeout → timeout (any timeout directs user to increase timeout_seconds)
|
||||
ConnectError → cannot_connect (DNS / TLS failure)
|
||||
ReadTimeout et al. → timeout (TCP connected, server slow)
|
||||
Any other exc → cannot_connect
|
||||
@@ -61,7 +61,7 @@ class TestCheckSsrfSafety:
|
||||
def test_non_https_scheme_rejected(self, url: str) -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self._call(url)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert "https" in (exc_info.value.detail or "").lower()
|
||||
|
||||
# --- private IP blocklist ---
|
||||
@@ -87,7 +87,7 @@ class TestCheckSsrfSafety:
|
||||
):
|
||||
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
|
||||
self._call("https://internal.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert ip in (exc_info.value.detail or "")
|
||||
|
||||
def test_public_ip_is_allowed(self) -> None:
|
||||
@@ -106,7 +106,7 @@ class TestCheckSsrfSafety:
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
self._call("https://no-such-host.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -158,13 +158,11 @@ class TestValidateEndpoint:
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_timeout_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectTimeout("timed out")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
assert self._call().status == HookValidateStatus.timeout
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -4,13 +4,23 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
|
||||
from onyx.natural_language_processing import utils as nlp_utils
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.server.features.projects import projects_file_utils as utils
|
||||
from onyx.server.settings.models import Settings
|
||||
|
||||
|
||||
class _Tokenizer:
|
||||
class _Tokenizer(BaseTokenizer):
|
||||
def encode(self, text: str) -> list[int]:
|
||||
return [1] * len(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
class _NonSeekableFile(BytesIO):
|
||||
def tell(self) -> int:
|
||||
@@ -29,10 +39,26 @@ def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
|
||||
return UploadFile(filename=filename, file=BytesIO(content), size=None)
|
||||
|
||||
|
||||
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _make_settings(upload_size_mb: int = 1, token_threshold_k: int = 100) -> Settings:
|
||||
return Settings(
|
||||
user_file_max_upload_size_mb=upload_size_mb,
|
||||
file_token_count_threshold_k=token_threshold_k,
|
||||
)
|
||||
|
||||
|
||||
def _patch_common_dependencies(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
upload_size_mb: int = 1,
|
||||
token_threshold_k: int = 100,
|
||||
) -> None:
|
||||
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
|
||||
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
|
||||
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
|
||||
monkeypatch.setattr(
|
||||
utils,
|
||||
"load_settings",
|
||||
lambda: _make_settings(upload_size_mb, token_threshold_k),
|
||||
)
|
||||
|
||||
|
||||
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
|
||||
@@ -76,9 +102,8 @@ def test_is_upload_too_large_logs_warning_when_size_unknown(
|
||||
def test_categorize_uploaded_files_accepts_size_under_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
# upload_size_mb=1 → max_bytes = 1*1024*1024; file size 99 is well under
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("small.png", size=99)
|
||||
@@ -91,9 +116,7 @@ def test_categorize_uploaded_files_accepts_size_under_limit(
|
||||
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload_no_size("small.png", content=b"x" * 99)
|
||||
@@ -106,12 +129,11 @@ def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
|
||||
def test_categorize_uploaded_files_accepts_size_at_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("edge.png", size=100)
|
||||
# 1 MB = 1048576 bytes; file at exactly that boundary should be accepted
|
||||
upload = _make_upload("edge.png", size=1048576)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
@@ -121,12 +143,10 @@ def test_categorize_uploaded_files_accepts_size_at_limit(
|
||||
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("large.png", size=101)
|
||||
upload = _make_upload("large.png", size=1048577) # 1 byte over 1 MB
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
@@ -137,13 +157,11 @@ def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
|
||||
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
small = _make_upload("small.png", size=50)
|
||||
large = _make_upload("large.png", size=101)
|
||||
large = _make_upload("large.png", size=1048577)
|
||||
|
||||
result = utils.categorize_uploaded_files([small, large], MagicMock())
|
||||
|
||||
@@ -153,15 +171,12 @@ def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
|
||||
def test_categorize_uploaded_files_enforces_size_limit_always(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
|
||||
upload = _make_upload("oversized.pdf", size=101)
|
||||
upload = _make_upload("oversized.pdf", size=1048577)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
@@ -172,14 +187,12 @@ def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_sk
|
||||
def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
|
||||
extract_mock = MagicMock(return_value="this should not run")
|
||||
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
|
||||
|
||||
oversized_doc = _make_upload("oversized.pdf", size=101)
|
||||
oversized_doc = _make_upload("oversized.pdf", size=1048577)
|
||||
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
|
||||
|
||||
extract_mock.assert_not_called()
|
||||
@@ -188,40 +201,219 @@ def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_accepts_python_file(
|
||||
def test_categorize_enforces_size_limit_when_upload_size_mb_is_positive(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
"""A positive upload_size_mb is always enforced."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
py_source = b'def hello():\n print("world")\n'
|
||||
monkeypatch.setattr(
|
||||
utils, "extract_file_text", lambda **_kwargs: py_source.decode()
|
||||
)
|
||||
|
||||
upload = _make_upload("script.py", size=len(py_source), content=py_source)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert result.acceptable[0].filename == "script.py"
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_rejects_binary_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: "")
|
||||
|
||||
binary_content = bytes(range(256)) * 4
|
||||
upload = _make_upload("data.bin", size=len(binary_content), content=binary_content)
|
||||
upload = _make_upload("huge.png", size=1048577, content=b"x")
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].filename == "data.bin"
|
||||
assert "Unsupported file type" in result.rejected[0].reason
|
||||
|
||||
|
||||
def test_categorize_enforces_token_limit_when_threshold_k_is_positive(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A positive token_threshold_k is always enforced."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=5)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
|
||||
|
||||
upload = _make_upload("big_image.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
|
||||
|
||||
def test_categorize_no_token_limit_when_threshold_k_is_zero(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""token_threshold_k=0 means no token limit; high-token files are accepted."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=0)
|
||||
monkeypatch.setattr(
|
||||
utils, "estimate_image_tokens_for_upload", lambda _upload: 999_999
|
||||
)
|
||||
|
||||
upload = _make_upload("huge_image.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.rejected) == 0
|
||||
assert len(result.acceptable) == 1
|
||||
|
||||
|
||||
def test_categorize_both_limits_enforced(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Both positive limits are enforced; file exceeding token limit is rejected."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=10, token_threshold_k=5)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 6000)
|
||||
|
||||
upload = _make_upload("over_tokens.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 5K token limit"
|
||||
|
||||
|
||||
def test_categorize_rejection_reason_contains_dynamic_values(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Rejection reasons reflect the admin-configured limits, not hardcoded values."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 8000)
|
||||
|
||||
# File within size limit but over token limit
|
||||
upload = _make_upload("tokens.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert result.rejected[0].reason == "Exceeds 7K token limit"
|
||||
|
||||
# File over size limit
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=42, token_threshold_k=7)
|
||||
oversized = _make_upload("big.png", size=42 * 1024 * 1024 + 1)
|
||||
result2 = utils.categorize_uploaded_files([oversized], MagicMock())
|
||||
|
||||
assert result2.rejected[0].reason == "Exceeds 42 MB file size limit"
|
||||
|
||||
|
||||
# --- count_tokens tests ---
|
||||
|
||||
|
||||
def test_count_tokens_small_text() -> None:
|
||||
"""Small text should be encoded in a single call and return correct count."""
|
||||
tokenizer = _Tokenizer()
|
||||
text = "hello world"
|
||||
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
|
||||
|
||||
|
||||
def test_count_tokens_chunked_matches_single_call() -> None:
|
||||
"""Chunked encoding should produce the same result as single-call for small text."""
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 1000
|
||||
assert count_tokens(text, tokenizer) == len(tokenizer.encode(text))
|
||||
|
||||
|
||||
def test_count_tokens_large_text_is_chunked(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Text exceeding _ENCODE_CHUNK_SIZE should be split into multiple encode calls."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 250
|
||||
# _Tokenizer returns 1 token per char, so total should be 250
|
||||
assert count_tokens(text, tokenizer) == 250
|
||||
|
||||
|
||||
def test_count_tokens_with_token_limit_exits_early(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When token_limit is set and exceeded, count_tokens should stop early."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
|
||||
encode_call_count = 0
|
||||
original_tokenizer = _Tokenizer()
|
||||
|
||||
class _CountingTokenizer(BaseTokenizer):
|
||||
def encode(self, text: str) -> list[int]:
|
||||
nonlocal encode_call_count
|
||||
encode_call_count += 1
|
||||
return original_tokenizer.encode(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
tokenizer = _CountingTokenizer()
|
||||
# 500 chars → 5 chunks of 100; limit=150 → should stop after 2 chunks
|
||||
text = "a" * 500
|
||||
result = count_tokens(text, tokenizer, token_limit=150)
|
||||
|
||||
assert result == 200 # 2 chunks × 100 tokens each
|
||||
assert encode_call_count == 2, "Should have stopped after 2 chunks"
|
||||
|
||||
|
||||
def test_count_tokens_with_token_limit_not_exceeded(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When token_limit is set but not exceeded, all chunks are encoded."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 250
|
||||
result = count_tokens(text, tokenizer, token_limit=1000)
|
||||
assert result == 250
|
||||
|
||||
|
||||
def test_count_tokens_no_limit_encodes_all_chunks(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Without token_limit, all chunks are encoded regardless of count."""
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
tokenizer = _Tokenizer()
|
||||
text = "a" * 500
|
||||
result = count_tokens(text, tokenizer)
|
||||
assert result == 500
|
||||
|
||||
|
||||
# --- early exit via token_limit in categorize tests ---
|
||||
|
||||
|
||||
def test_categorize_early_exits_tokenization_for_large_text(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Large text files should be rejected via early-exit tokenization
|
||||
without encoding all chunks."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
|
||||
# token_threshold = 1000; _ENCODE_CHUNK_SIZE = 100 → text of 500 chars = 5 chunks
|
||||
# Should stop after 2nd chunk (200 tokens > 1000? No... need 1 token per char)
|
||||
# With _Tokenizer: 1 token per char. threshold=1000, chunk=100 → need 11 chunks
|
||||
# Let's use a bigger text
|
||||
monkeypatch.setattr(nlp_utils, "_ENCODE_CHUNK_SIZE", 100)
|
||||
large_text = "x" * 5000 # 5000 tokens, threshold 1000
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: large_text)
|
||||
|
||||
encode_call_count = 0
|
||||
original_tokenizer = _Tokenizer()
|
||||
|
||||
class _CountingTokenizer(BaseTokenizer):
|
||||
def encode(self, text: str) -> list[int]:
|
||||
nonlocal encode_call_count
|
||||
encode_call_count += 1
|
||||
return original_tokenizer.encode(text)
|
||||
|
||||
def tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def decode(self, _tokens: list[int]) -> str:
|
||||
return ""
|
||||
|
||||
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _CountingTokenizer())
|
||||
|
||||
upload = _make_upload("big.txt", size=5000, content=large_text.encode())
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.rejected) == 1
|
||||
assert "token limit" in result.rejected[0].reason
|
||||
# 5000 chars / 100 chunk_size = 50 chunks total; should stop well before all 50
|
||||
assert (
|
||||
encode_call_count < 50
|
||||
), f"Expected early exit but encoded {encode_call_count} chunks out of 50"
|
||||
|
||||
|
||||
def test_categorize_text_under_token_limit_accepted(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Text files under the token threshold should be accepted with exact count."""
|
||||
_patch_common_dependencies(monkeypatch, upload_size_mb=1000, token_threshold_k=1)
|
||||
small_text = "x" * 500 # 500 tokens < 1000 threshold
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: small_text)
|
||||
|
||||
upload = _make_upload("ok.txt", size=500, content=small_text.encode())
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert result.acceptable_file_to_token_count["ok.txt"] == 500
|
||||
|
||||
@@ -1,12 +1,23 @@
|
||||
import pytest
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings import store as settings_store
|
||||
from onyx.server.settings.models import (
|
||||
DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB,
|
||||
)
|
||||
from onyx.server.settings.models import DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
from onyx.server.settings.models import Settings
|
||||
|
||||
|
||||
class _FakeKvStore:
|
||||
def __init__(self, data: dict | None = None) -> None:
|
||||
self._data = data
|
||||
|
||||
def load(self, _key: str) -> dict:
|
||||
raise KvKeyNotFoundError()
|
||||
if self._data is None:
|
||||
raise KvKeyNotFoundError()
|
||||
return self._data
|
||||
|
||||
|
||||
class _FakeCache:
|
||||
@@ -20,13 +31,140 @@ class _FakeCache:
|
||||
self._vals[key] = value.encode("utf-8")
|
||||
|
||||
|
||||
def test_load_settings_includes_user_file_max_upload_size_mb(
|
||||
def test_load_settings_uses_model_defaults_when_no_stored_value(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When no settings are stored (vector DB enabled), load_settings() should
|
||||
resolve the default token threshold to 200."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", False)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 77
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
assert (
|
||||
settings.file_token_count_threshold_k
|
||||
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_VECTOR_DB
|
||||
)
|
||||
|
||||
|
||||
def test_load_settings_uses_high_token_default_when_vector_db_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When vector DB is disabled and no settings are stored, the token
|
||||
threshold should default to 10000 (10M tokens)."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
assert (
|
||||
settings.file_token_count_threshold_k
|
||||
== DEFAULT_FILE_TOKEN_COUNT_THRESHOLD_K_NO_VECTOR_DB
|
||||
)
|
||||
|
||||
|
||||
def test_load_settings_preserves_explicit_value_when_vector_db_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When vector DB is disabled but admin explicitly set a token threshold,
|
||||
that value should be preserved (not overridden by the 10000 default)."""
|
||||
stored = Settings(file_token_count_threshold_k=500).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.file_token_count_threshold_k == 500
|
||||
|
||||
|
||||
def test_load_settings_preserves_zero_token_threshold(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 means 'no limit' and should be preserved."""
|
||||
stored = Settings(file_token_count_threshold_k=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DISABLE_VECTOR_DB", True)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.file_token_count_threshold_k == 0
|
||||
|
||||
|
||||
def test_load_settings_resolves_zero_upload_size_to_default(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 should be treated as unset and resolved to the default."""
|
||||
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
|
||||
|
||||
def test_load_settings_clamps_upload_size_to_env_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When the stored upload size exceeds MAX_ALLOWED_UPLOAD_SIZE_MB, it should
|
||||
be clamped to the env-configured maximum."""
|
||||
stored = Settings(user_file_max_upload_size_mb=500).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 250
|
||||
|
||||
|
||||
def test_load_settings_preserves_upload_size_within_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When the stored upload size is within MAX_ALLOWED_UPLOAD_SIZE_MB, it should
|
||||
be preserved unchanged."""
|
||||
stored = Settings(user_file_max_upload_size_mb=150).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 250)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 150
|
||||
|
||||
|
||||
def test_load_settings_zero_upload_size_resolves_to_default(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A value of 0 should be treated as unset and resolved to the default,
|
||||
clamped to MAX_ALLOWED_UPLOAD_SIZE_MB."""
|
||||
stored = Settings(user_file_max_upload_size_mb=0).model_dump()
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore(stored))
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 100)
|
||||
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 100
|
||||
|
||||
|
||||
def test_load_settings_default_clamped_to_max(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""When DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB exceeds MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
the effective default should be min(DEFAULT, MAX)."""
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", 100)
|
||||
monkeypatch.setattr(settings_store, "MAX_ALLOWED_UPLOAD_SIZE_MB", 50)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 50
|
||||
|
||||
@@ -39,6 +39,22 @@ server {
|
||||
# Conditionally include MCP location configuration
|
||||
include /etc/nginx/conf.d/mcp.conf.inc;
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
|
||||
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
|
||||
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
# Match both /api/* and /openapi.json in a single rule
|
||||
location ~ ^/(api|openapi.json)(/.*)?$ {
|
||||
# Rewrite /api prefixed matched paths
|
||||
|
||||
@@ -39,6 +39,20 @@ server {
|
||||
# Conditionally include MCP location configuration
|
||||
include /etc/nginx/conf.d/mcp.conf.inc;
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
# Match both /api/* and /openapi.json in a single rule
|
||||
location ~ ^/(api|openapi.json)(/.*)?$ {
|
||||
# Rewrite /api prefixed matched paths
|
||||
|
||||
@@ -39,6 +39,23 @@ server {
|
||||
# Conditionally include MCP location configuration
|
||||
include /etc/nginx/conf.d/mcp.conf.inc;
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT}s;
|
||||
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}s;
|
||||
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}s;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
# Match both /api/* and /openapi.json in a single rule
|
||||
location ~ ^/(api|openapi.json)(/.*)?$ {
|
||||
# Rewrite /api prefixed matched paths
|
||||
|
||||
@@ -66,10 +66,3 @@ DB_READONLY_PASSWORD=password
|
||||
# Show extra/uncommon connectors
|
||||
# See https://docs.onyx.app/admins/connectors/overview for a full list of connectors
|
||||
SHOW_EXTRA_CONNECTORS=False
|
||||
|
||||
# User File Upload Configuration
|
||||
# Skip the token count threshold check (100,000 tokens) for uploaded files
|
||||
# For self-hosted: set to true to skip for all users
|
||||
#SKIP_USERFILE_THRESHOLD=false
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
#SKIP_USERFILE_THRESHOLD_TENANT_IDS=
|
||||
|
||||
@@ -35,6 +35,10 @@ USER_AUTH_SECRET=""
|
||||
|
||||
## Chat Configuration
|
||||
# HARD_DELETE_CHATS=
|
||||
# MAX_ALLOWED_UPLOAD_SIZE_MB=250
|
||||
# Default per-user upload size limit (MB) when no admin value is set.
|
||||
# Automatically clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at runtime.
|
||||
# DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=100
|
||||
|
||||
## Base URL for redirects
|
||||
# WEB_DOMAIN=
|
||||
@@ -42,13 +46,6 @@ USER_AUTH_SECRET=""
|
||||
## Enterprise Features, requires a paid plan and licenses
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=false
|
||||
|
||||
## User File Upload Configuration
|
||||
# Skip the token count threshold check (100,000 tokens) for uploaded files
|
||||
# For self-hosted: set to true to skip for all users
|
||||
# SKIP_USERFILE_THRESHOLD=false
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
# SKIP_USERFILE_THRESHOLD_TENANT_IDS=
|
||||
|
||||
|
||||
################################################################################
|
||||
## SERVICES CONFIGURATIONS
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.36
|
||||
version: 0.4.38
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
|
||||
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_docfetching.replicaCount) 0) }}
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-docfetching-metrics
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_docfetching.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
metrics: "true"
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: 9092
|
||||
targetPort: metrics
|
||||
protocol: TCP
|
||||
name: metrics
|
||||
selector:
|
||||
{{- include "onyx.selectorLabels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_docfetching.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -73,6 +73,10 @@ spec:
|
||||
"-Q",
|
||||
"connector_doc_fetching",
|
||||
]
|
||||
ports:
|
||||
- name: metrics
|
||||
containerPort: 9092
|
||||
protocol: TCP
|
||||
resources:
|
||||
{{- toYaml .Values.celery_worker_docfetching.resources | nindent 12 }}
|
||||
envFrom:
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
|
||||
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_docprocessing.replicaCount) 0) }}
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-docprocessing-metrics
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_docprocessing.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_docprocessing.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
metrics: "true"
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: 9093
|
||||
targetPort: metrics
|
||||
protocol: TCP
|
||||
name: metrics
|
||||
selector:
|
||||
{{- include "onyx.selectorLabels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_docprocessing.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_docprocessing.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -73,6 +73,10 @@ spec:
|
||||
"-Q",
|
||||
"docprocessing",
|
||||
]
|
||||
ports:
|
||||
- name: metrics
|
||||
containerPort: 9093
|
||||
protocol: TCP
|
||||
resources:
|
||||
{{- toYaml .Values.celery_worker_docprocessing.resources | nindent 12 }}
|
||||
envFrom:
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
{{- /* Metrics port must match the default in metrics_server.py (_DEFAULT_PORTS).
|
||||
Do NOT use PROMETHEUS_METRICS_PORT env var in Helm — each worker needs its own port. */ -}}
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_monitoring.replicaCount) 0) }}
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-monitoring-metrics
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_monitoring.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_monitoring.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
metrics: "true"
|
||||
spec:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
- port: 9096
|
||||
targetPort: metrics
|
||||
protocol: TCP
|
||||
name: metrics
|
||||
selector:
|
||||
{{- include "onyx.selectorLabels" . | nindent 4 }}
|
||||
{{- if .Values.celery_worker_monitoring.deploymentLabels }}
|
||||
{{- toYaml .Values.celery_worker_monitoring.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -70,6 +70,10 @@ spec:
|
||||
"-Q",
|
||||
"monitoring",
|
||||
]
|
||||
ports:
|
||||
- name: metrics
|
||||
containerPort: 9096
|
||||
protocol: TCP
|
||||
resources:
|
||||
{{- toYaml .Values.celery_worker_monitoring.resources | nindent 12 }}
|
||||
envFrom:
|
||||
|
||||
@@ -63,6 +63,22 @@ data:
|
||||
}
|
||||
{{- end }}
|
||||
|
||||
location ~ ^/scim(/.*)?$ {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
# timeout settings
|
||||
proxy_connect_timeout {{ .Values.nginx.timeouts.connect }}s;
|
||||
proxy_send_timeout {{ .Values.nginx.timeouts.send }}s;
|
||||
proxy_read_timeout {{ .Values.nginx.timeouts.read }}s;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
location ~ ^/(api|openapi\.json)(/.*)?$ {
|
||||
rewrite ^/api(/.*)$ $1 break;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
|
||||
@@ -282,7 +282,7 @@ nginx:
|
||||
# The ingress-nginx subchart doesn't auto-detect our custom ConfigMap changes.
|
||||
# Workaround: Helm upgrade will restart if the following annotation value changes.
|
||||
podAnnotations:
|
||||
onyx.app/nginx-config-version: "2"
|
||||
onyx.app/nginx-config-version: "3"
|
||||
|
||||
# Propagate DOMAIN into nginx so server_name continues to use the same env var
|
||||
extraEnvs:
|
||||
@@ -1285,11 +1285,5 @@ configMap:
|
||||
DOMAIN: "localhost"
|
||||
# Chat Configs
|
||||
HARD_DELETE_CHATS: ""
|
||||
# User File Upload Configuration
|
||||
# Skip the token count threshold check (100,000 tokens) for uploaded files
|
||||
# For self-hosted: set to true to skip for all users
|
||||
SKIP_USERFILE_THRESHOLD: ""
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_IDS: ""
|
||||
# Maximum user upload file size in MB for chat/projects uploads
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB: ""
|
||||
MAX_ALLOWED_UPLOAD_SIZE_MB: ""
|
||||
DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB: ""
|
||||
|
||||
@@ -28,7 +28,7 @@ Some commands require external tools to be installed and configured:
|
||||
- **uv** - Required for `backend` commands
|
||||
- Install from [docs.astral.sh/uv](https://docs.astral.sh/uv/)
|
||||
|
||||
- **GitHub CLI** (`gh`) - Required for `run-ci` and `cherry-pick` commands
|
||||
- **GitHub CLI** (`gh`) - Required for `run-ci`, `cherry-pick`, and `trace` commands
|
||||
- Install from [cli.github.com](https://cli.github.com/)
|
||||
- Authenticate with `gh auth login`
|
||||
|
||||
@@ -412,6 +412,62 @@ The `compare` subcommand writes a `summary.json` alongside the report with aggre
|
||||
counts (changed, added, removed, unchanged). The HTML report is only generated when
|
||||
visual differences are detected.
|
||||
|
||||
### `trace` - View Playwright Traces from CI
|
||||
|
||||
Download Playwright trace artifacts from a GitHub Actions run and open them
|
||||
with `playwright show-trace`. Traces are only generated for failing tests
|
||||
(`retain-on-failure`).
|
||||
|
||||
```shell
|
||||
ods trace [run-id-or-url]
|
||||
```
|
||||
|
||||
The run can be specified as a numeric run ID, a full GitHub Actions URL, or
|
||||
omitted to find the latest Playwright run for the current branch.
|
||||
|
||||
**Flags:**
|
||||
|
||||
| Flag | Default | Description |
|
||||
|------|---------|-------------|
|
||||
| `--branch`, `-b` | | Find latest run for this branch |
|
||||
| `--pr` | | Find latest run for this PR number |
|
||||
| `--project`, `-p` | | Filter to a specific project (`admin`, `exclusive`, `lite`) |
|
||||
| `--list`, `-l` | `false` | List available traces without opening |
|
||||
| `--no-open` | `false` | Download traces but don't open them |
|
||||
|
||||
When multiple traces are found, an interactive picker lets you select which
|
||||
traces to open. Use arrow keys or `j`/`k` to navigate, `space` to toggle,
|
||||
`a` to select all, `n` to deselect all, and `enter` to open. Falls back to a
|
||||
plain-text prompt when no TTY is available.
|
||||
|
||||
Downloaded artifacts are cached in `/tmp/ods-traces/<run-id>/` so repeated
|
||||
invocations for the same run are instant.
|
||||
|
||||
**Examples:**
|
||||
|
||||
```shell
|
||||
# Latest run for the current branch
|
||||
ods trace
|
||||
|
||||
# Specific run ID
|
||||
ods trace 12345678
|
||||
|
||||
# Full GitHub Actions URL
|
||||
ods trace https://github.com/onyx-dot-app/onyx/actions/runs/12345678
|
||||
|
||||
# Latest run for a PR
|
||||
ods trace --pr 9500
|
||||
|
||||
# Latest run for a specific branch
|
||||
ods trace --branch main
|
||||
|
||||
# Only download admin project traces
|
||||
ods trace --project admin
|
||||
|
||||
# List traces without opening
|
||||
ods trace --list
|
||||
```
|
||||
|
||||
### Testing Changes Locally (Dry Run)
|
||||
|
||||
Both `run-ci` and `cherry-pick` support `--dry-run` to test without making remote changes:
|
||||
|
||||
@@ -55,6 +55,7 @@ func NewRootCommand() *cobra.Command {
|
||||
cmd.AddCommand(NewWebCommand())
|
||||
cmd.AddCommand(NewLatestStableTagCommand())
|
||||
cmd.AddCommand(NewWhoisCommand())
|
||||
cmd.AddCommand(NewTraceCommand())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
556
tools/ods/cmd/trace.go
Normal file
556
tools/ods/cmd/trace.go
Normal file
@@ -0,0 +1,556 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/git"
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/paths"
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/tui"
|
||||
)
|
||||
|
||||
const playwrightWorkflow = "Run Playwright Tests"
|
||||
|
||||
// TraceOptions holds options for the trace command
|
||||
type TraceOptions struct {
|
||||
Branch string
|
||||
PR string
|
||||
Project string
|
||||
List bool
|
||||
NoOpen bool
|
||||
}
|
||||
|
||||
// traceInfo describes a single trace.zip found in the downloaded artifacts.
|
||||
type traceInfo struct {
|
||||
Path string // absolute path to trace.zip
|
||||
Project string // project group extracted from artifact dir (e.g. "admin", "admin-shard-1")
|
||||
TestDir string // test directory name (human-readable-ish)
|
||||
}
|
||||
|
||||
// NewTraceCommand creates a new trace command
|
||||
func NewTraceCommand() *cobra.Command {
|
||||
opts := &TraceOptions{}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "trace [run-id-or-url]",
|
||||
Short: "Download and view Playwright traces from GitHub Actions",
|
||||
Long: `Download Playwright trace artifacts from a GitHub Actions run and open them
|
||||
with 'playwright show-trace'.
|
||||
|
||||
The run can be specified as:
|
||||
- A GitHub Actions run ID (numeric)
|
||||
- A full GitHub Actions run URL
|
||||
- Omitted, to find the latest Playwright run for the current branch
|
||||
|
||||
You can also look up the latest run by branch name or PR number.
|
||||
|
||||
Examples:
|
||||
ods trace # latest run for current branch
|
||||
ods trace 12345678 # specific run ID
|
||||
ods trace https://github.com/onyx-dot-app/onyx/actions/runs/12345678
|
||||
ods trace --pr 9500 # latest run for PR #9500
|
||||
ods trace --branch main # latest run for main branch
|
||||
ods trace --project admin # only download admin project traces
|
||||
ods trace --list # list available traces without opening`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runTrace(args, opts)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVarP(&opts.Branch, "branch", "b", "", "Find latest run for this branch")
|
||||
cmd.Flags().StringVar(&opts.PR, "pr", "", "Find latest run for this PR number")
|
||||
cmd.Flags().StringVarP(&opts.Project, "project", "p", "", "Filter to a specific project (admin, exclusive, lite)")
|
||||
cmd.Flags().BoolVarP(&opts.List, "list", "l", false, "List available traces without opening")
|
||||
cmd.Flags().BoolVar(&opts.NoOpen, "no-open", false, "Download traces but don't open them")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// ghRun represents a GitHub Actions workflow run from `gh run list`
|
||||
type ghRun struct {
|
||||
DatabaseID int64 `json:"databaseId"`
|
||||
Status string `json:"status"`
|
||||
Conclusion string `json:"conclusion"`
|
||||
HeadBranch string `json:"headBranch"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
func runTrace(args []string, opts *TraceOptions) {
|
||||
git.CheckGitHubCLI()
|
||||
|
||||
runID, err := resolveRunID(args, opts)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to resolve run: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("Using run ID: %s", runID)
|
||||
|
||||
destDir, err := downloadTraceArtifacts(runID, opts.Project)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to download artifacts: %v", err)
|
||||
}
|
||||
|
||||
traces, err := findTraceInfos(destDir, runID)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find traces: %v", err)
|
||||
}
|
||||
|
||||
if len(traces) == 0 {
|
||||
log.Info("No trace files found in the downloaded artifacts.")
|
||||
log.Info("Traces are only generated for failing tests (retain-on-failure).")
|
||||
return
|
||||
}
|
||||
|
||||
projects := groupByProject(traces)
|
||||
|
||||
if opts.List || opts.NoOpen {
|
||||
printTraceList(traces, projects)
|
||||
fmt.Printf("\nTraces downloaded to: %s\n", destDir)
|
||||
return
|
||||
}
|
||||
|
||||
if len(traces) == 1 {
|
||||
openTraces(traces)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
selected := selectTraces(traces, projects)
|
||||
if len(selected) == 0 {
|
||||
return
|
||||
}
|
||||
openTraces(selected)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRunID determines the run ID from the provided arguments and options.
|
||||
func resolveRunID(args []string, opts *TraceOptions) (string, error) {
|
||||
if len(args) == 1 {
|
||||
return parseRunIDFromArg(args[0])
|
||||
}
|
||||
|
||||
if opts.PR != "" {
|
||||
return findLatestRunForPR(opts.PR)
|
||||
}
|
||||
|
||||
branch := opts.Branch
|
||||
if branch == "" {
|
||||
var err error
|
||||
branch, err = git.GetCurrentBranch()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get current branch: %w", err)
|
||||
}
|
||||
if branch == "" {
|
||||
return "", fmt.Errorf("detached HEAD; specify a --branch, --pr, or run ID")
|
||||
}
|
||||
log.Infof("Using current branch: %s", branch)
|
||||
}
|
||||
|
||||
return findLatestRunForBranch(branch)
|
||||
}
|
||||
|
||||
var runURLPattern = regexp.MustCompile(`/actions/runs/(\d+)`)
|
||||
|
||||
// parseRunIDFromArg extracts a run ID from either a numeric string or a full URL.
|
||||
func parseRunIDFromArg(arg string) (string, error) {
|
||||
if matched, _ := regexp.MatchString(`^\d+$`, arg); matched {
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
matches := runURLPattern.FindStringSubmatch(arg)
|
||||
if matches != nil {
|
||||
return matches[1], nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("could not parse run ID from %q; expected a numeric ID or GitHub Actions URL", arg)
|
||||
}
|
||||
|
||||
// findLatestRunForBranch finds the most recent Playwright workflow run for a branch.
|
||||
func findLatestRunForBranch(branch string) (string, error) {
|
||||
log.Infof("Looking up latest Playwright run for branch: %s", branch)
|
||||
|
||||
cmd := exec.Command("gh", "run", "list",
|
||||
"--workflow", playwrightWorkflow,
|
||||
"--branch", branch,
|
||||
"--limit", "1",
|
||||
"--json", "databaseId,status,conclusion,headBranch,url",
|
||||
)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", ghError(err, "gh run list failed")
|
||||
}
|
||||
|
||||
var runs []ghRun
|
||||
if err := json.Unmarshal(output, &runs); err != nil {
|
||||
return "", fmt.Errorf("failed to parse run list: %w", err)
|
||||
}
|
||||
|
||||
if len(runs) == 0 {
|
||||
return "", fmt.Errorf("no Playwright runs found for branch %q", branch)
|
||||
}
|
||||
|
||||
run := runs[0]
|
||||
log.Infof("Found run: %s (status: %s, conclusion: %s)", run.URL, run.Status, run.Conclusion)
|
||||
return fmt.Sprintf("%d", run.DatabaseID), nil
|
||||
}
|
||||
|
||||
// findLatestRunForPR finds the most recent Playwright workflow run for a PR.
|
||||
func findLatestRunForPR(prNumber string) (string, error) {
|
||||
log.Infof("Looking up branch for PR #%s", prNumber)
|
||||
|
||||
cmd := exec.Command("gh", "pr", "view", prNumber,
|
||||
"--json", "headRefName",
|
||||
"--jq", ".headRefName",
|
||||
)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", ghError(err, "gh pr view failed")
|
||||
}
|
||||
|
||||
branch := strings.TrimSpace(string(output))
|
||||
if branch == "" {
|
||||
return "", fmt.Errorf("could not determine branch for PR #%s", prNumber)
|
||||
}
|
||||
|
||||
log.Infof("PR #%s is on branch: %s", prNumber, branch)
|
||||
return findLatestRunForBranch(branch)
|
||||
}
|
||||
|
||||
// downloadTraceArtifacts downloads playwright trace artifacts for a run.
|
||||
// Returns the path to the download directory.
|
||||
func downloadTraceArtifacts(runID string, project string) (string, error) {
|
||||
cacheKey := runID
|
||||
if project != "" {
|
||||
cacheKey = runID + "-" + project
|
||||
}
|
||||
destDir := filepath.Join(os.TempDir(), "ods-traces", cacheKey)
|
||||
|
||||
// Reuse a previous download if traces exist
|
||||
if info, err := os.Stat(destDir); err == nil && info.IsDir() {
|
||||
traces, _ := findTraces(destDir)
|
||||
if len(traces) > 0 {
|
||||
log.Infof("Using cached download at %s", destDir)
|
||||
return destDir, nil
|
||||
}
|
||||
_ = os.RemoveAll(destDir)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(destDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create directory %s: %w", destDir, err)
|
||||
}
|
||||
|
||||
ghArgs := []string{"run", "download", runID, "--dir", destDir}
|
||||
|
||||
if project != "" {
|
||||
ghArgs = append(ghArgs, "--pattern", fmt.Sprintf("playwright-test-results-%s-*", project))
|
||||
} else {
|
||||
ghArgs = append(ghArgs, "--pattern", "playwright-test-results-*")
|
||||
}
|
||||
|
||||
log.Infof("Downloading trace artifacts...")
|
||||
log.Debugf("Running: gh %s", strings.Join(ghArgs, " "))
|
||||
|
||||
cmd := exec.Command("gh", ghArgs...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
_ = os.RemoveAll(destDir)
|
||||
return "", fmt.Errorf("gh run download failed: %w\nMake sure the run ID is correct and the artifacts haven't expired (30 day retention)", err)
|
||||
}
|
||||
|
||||
return destDir, nil
|
||||
}
|
||||
|
||||
// findTraces recursively finds all trace.zip files under a directory.
|
||||
func findTraces(root string) ([]string, error) {
|
||||
var traces []string
|
||||
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() && info.Name() == "trace.zip" {
|
||||
traces = append(traces, path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return traces, err
|
||||
}
|
||||
|
||||
// findTraceInfos walks the download directory and returns structured trace info.
|
||||
// Expects: destDir/{artifact-dir}/{test-dir}/trace.zip
|
||||
func findTraceInfos(destDir, runID string) ([]traceInfo, error) {
|
||||
var traces []traceInfo
|
||||
err := filepath.Walk(destDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.IsDir() || info.Name() != "trace.zip" {
|
||||
return nil
|
||||
}
|
||||
|
||||
rel, _ := filepath.Rel(destDir, path)
|
||||
parts := strings.SplitN(rel, string(filepath.Separator), 3)
|
||||
|
||||
artifactDir := ""
|
||||
testDir := filepath.Base(filepath.Dir(path))
|
||||
if len(parts) >= 2 {
|
||||
artifactDir = parts[0]
|
||||
testDir = parts[1]
|
||||
}
|
||||
|
||||
traces = append(traces, traceInfo{
|
||||
Path: path,
|
||||
Project: extractProject(artifactDir, runID),
|
||||
TestDir: testDir,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
sort.Slice(traces, func(i, j int) bool {
|
||||
pi, pj := projectSortKey(traces[i].Project), projectSortKey(traces[j].Project)
|
||||
if pi != pj {
|
||||
return pi < pj
|
||||
}
|
||||
return traces[i].TestDir < traces[j].TestDir
|
||||
})
|
||||
|
||||
return traces, err
|
||||
}
|
||||
|
||||
// extractProject derives a project group from an artifact directory name.
|
||||
// e.g. "playwright-test-results-admin-12345" -> "admin"
|
||||
//
|
||||
// "playwright-test-results-admin-shard-1-12345" -> "admin-shard-1"
|
||||
func extractProject(artifactDir, runID string) string {
|
||||
name := strings.TrimPrefix(artifactDir, "playwright-test-results-")
|
||||
name = strings.TrimSuffix(name, "-"+runID)
|
||||
if name == "" {
|
||||
return artifactDir
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// projectSortKey returns a sort-friendly key that orders admin < exclusive < lite.
|
||||
func projectSortKey(project string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(project, "admin"):
|
||||
return "0-" + project
|
||||
case strings.HasPrefix(project, "exclusive"):
|
||||
return "1-" + project
|
||||
case strings.HasPrefix(project, "lite"):
|
||||
return "2-" + project
|
||||
default:
|
||||
return "3-" + project
|
||||
}
|
||||
}
|
||||
|
||||
// groupByProject returns an ordered list of unique project names found in traces.
|
||||
func groupByProject(traces []traceInfo) []string {
|
||||
seen := map[string]bool{}
|
||||
var projects []string
|
||||
for _, t := range traces {
|
||||
if !seen[t.Project] {
|
||||
seen[t.Project] = true
|
||||
projects = append(projects, t.Project)
|
||||
}
|
||||
}
|
||||
sort.Slice(projects, func(i, j int) bool {
|
||||
return projectSortKey(projects[i]) < projectSortKey(projects[j])
|
||||
})
|
||||
return projects
|
||||
}
|
||||
|
||||
// printTraceList displays traces grouped by project.
|
||||
func printTraceList(traces []traceInfo, projects []string) {
|
||||
fmt.Printf("\nFound %d trace(s) across %d project(s):\n", len(traces), len(projects))
|
||||
|
||||
idx := 1
|
||||
for _, proj := range projects {
|
||||
count := 0
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
count++
|
||||
}
|
||||
}
|
||||
fmt.Printf("\n %s (%d):\n", proj, count)
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
fmt.Printf(" [%2d] %s\n", idx, t.TestDir)
|
||||
idx++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// selectTraces tries the TUI picker first, falling back to a plain-text
|
||||
// prompt when the terminal cannot be initialised (e.g. piped output).
|
||||
func selectTraces(traces []traceInfo, projects []string) []traceInfo {
|
||||
// Build picker groups in the same order as the sorted traces slice.
|
||||
var groups []tui.PickerGroup
|
||||
for _, proj := range projects {
|
||||
var items []string
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
items = append(items, t.TestDir)
|
||||
}
|
||||
}
|
||||
groups = append(groups, tui.PickerGroup{Label: proj, Items: items})
|
||||
}
|
||||
|
||||
indices, err := tui.Pick(groups)
|
||||
if err != nil {
|
||||
// Terminal not available — fall back to text prompt
|
||||
log.Debugf("TUI picker unavailable: %v", err)
|
||||
printTraceList(traces, projects)
|
||||
return promptTraceSelection(traces, projects)
|
||||
}
|
||||
if indices == nil {
|
||||
return nil // user cancelled
|
||||
}
|
||||
|
||||
selected := make([]traceInfo, len(indices))
|
||||
for i, idx := range indices {
|
||||
selected[i] = traces[idx]
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
// promptTraceSelection asks the user which traces to open via plain text.
|
||||
// Accepts numbers (1,3,5), ranges (1-5), "all", or a project name.
|
||||
func promptTraceSelection(traces []traceInfo, projects []string) []traceInfo {
|
||||
fmt.Printf("\nOpen which traces? (e.g. 1,3,5 | 1-5 | all | %s): ", strings.Join(projects, " | "))
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to read input: %v", err)
|
||||
}
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
if input == "" || strings.EqualFold(input, "all") {
|
||||
return traces
|
||||
}
|
||||
|
||||
// Check if input matches a project name
|
||||
for _, proj := range projects {
|
||||
if strings.EqualFold(input, proj) {
|
||||
var selected []traceInfo
|
||||
for _, t := range traces {
|
||||
if t.Project == proj {
|
||||
selected = append(selected, t)
|
||||
}
|
||||
}
|
||||
return selected
|
||||
}
|
||||
}
|
||||
|
||||
// Parse as number/range selection
|
||||
indices := parseTraceSelection(input, len(traces))
|
||||
if len(indices) == 0 {
|
||||
log.Warn("No valid selection; opening all traces")
|
||||
return traces
|
||||
}
|
||||
|
||||
selected := make([]traceInfo, len(indices))
|
||||
for i, idx := range indices {
|
||||
selected[i] = traces[idx]
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
// parseTraceSelection parses a comma-separated list of numbers and ranges into
|
||||
// 0-based indices. Input is 1-indexed (matches display). Out-of-range values
|
||||
// are silently ignored.
|
||||
func parseTraceSelection(input string, max int) []int {
|
||||
var result []int
|
||||
seen := map[int]bool{}
|
||||
|
||||
for _, part := range strings.Split(input, ",") {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if idx := strings.Index(part, "-"); idx > 0 {
|
||||
lo, err1 := strconv.Atoi(strings.TrimSpace(part[:idx]))
|
||||
hi, err2 := strconv.Atoi(strings.TrimSpace(part[idx+1:]))
|
||||
if err1 != nil || err2 != nil {
|
||||
continue
|
||||
}
|
||||
for i := lo; i <= hi; i++ {
|
||||
zi := i - 1
|
||||
if zi >= 0 && zi < max && !seen[zi] {
|
||||
result = append(result, zi)
|
||||
seen[zi] = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
n, err := strconv.Atoi(part)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
zi := n - 1
|
||||
if zi >= 0 && zi < max && !seen[zi] {
|
||||
result = append(result, zi)
|
||||
seen[zi] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// openTraces opens the selected traces with playwright show-trace,
|
||||
// running npx from the web/ directory to use the project's Playwright version.
|
||||
func openTraces(traces []traceInfo) {
|
||||
tracePaths := make([]string, len(traces))
|
||||
for i, t := range traces {
|
||||
tracePaths[i] = t.Path
|
||||
}
|
||||
|
||||
args := append([]string{"playwright", "show-trace"}, tracePaths...)
|
||||
|
||||
log.Infof("Opening %d trace(s) with playwright show-trace...", len(traces))
|
||||
cmd := exec.Command("npx", args...)
|
||||
|
||||
// Run from web/ to pick up the locally-installed Playwright version
|
||||
if root, err := paths.GitRoot(); err == nil {
|
||||
cmd.Dir = filepath.Join(root, "web")
|
||||
}
|
||||
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
// Normal exit (e.g. user closed the window) — just log and return
|
||||
// so the picker loop can continue.
|
||||
log.Debugf("playwright exited with code %d", exitErr.ExitCode())
|
||||
return
|
||||
}
|
||||
log.Errorf("playwright show-trace failed: %v\nMake sure Playwright is installed (npx playwright install)", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ghError wraps a gh CLI error with stderr output.
|
||||
func ghError(err error, msg string) error {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return fmt.Errorf("%s: %w: %s", msg, err, string(exitErr.Stderr))
|
||||
}
|
||||
return fmt.Errorf("%s: %w", msg, err)
|
||||
}
|
||||
@@ -3,13 +3,19 @@ module github.com/onyx-dot-app/onyx/tools/ods
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/gdamore/tcell/v2 v2.13.8
|
||||
github.com/jmelahman/tag v0.5.2
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/sirupsen/logrus v1.9.4
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/pflag v1.0.10
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/term v0.41.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
)
|
||||
|
||||
@@ -1,30 +1,68 @@
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw=
|
||||
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
|
||||
github.com/gdamore/tcell/v2 v2.13.8 h1:Mys/Kl5wfC/GcC5Cx4C2BIQH9dbnhnkPgS9/wF3RlfU=
|
||||
github.com/gdamore/tcell/v2 v2.13.8/go.mod h1:+Wfe208WDdB7INEtCsNrAN6O2m+wsTPk1RAovjaILlo=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jmelahman/tag v0.5.2 h1:g6A/aHehu5tkA31mPoDsXBNr1FigZ9A82Y8WVgb/WsM=
|
||||
github.com/jmelahman/tag v0.5.2/go.mod h1:qmuqk19B1BKkpcg3kn7l/Eey+UqucLxgOWkteUGiG4Q=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
|
||||
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
|
||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
419
tools/ods/internal/tui/picker.go
Normal file
419
tools/ods/internal/tui/picker.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
)
|
||||
|
||||
// PickerGroup represents a labelled group of selectable items.
|
||||
type PickerGroup struct {
|
||||
Label string
|
||||
Items []string
|
||||
}
|
||||
|
||||
// entry is a single row in the picker (either a group header or an item).
|
||||
type entry struct {
|
||||
label string
|
||||
isHeader bool
|
||||
selected bool
|
||||
groupIdx int
|
||||
flatIdx int // index across all items (ignoring headers), -1 for headers
|
||||
}
|
||||
|
||||
// Pick shows a full-screen grouped multi-select picker.
|
||||
// All items start deselected. Returns the flat indices of selected items
|
||||
// (0-based, spanning all groups in order). Returns nil if cancelled.
|
||||
// Returns a non-nil error if the terminal cannot be initialised, in which
|
||||
// case the caller should fall back to a simpler prompt.
|
||||
func Pick(groups []PickerGroup) ([]int, error) {
|
||||
screen, err := tcell.NewScreen()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := screen.Init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer screen.Fini()
|
||||
|
||||
entries := buildEntries(groups)
|
||||
totalItems := countItems(entries)
|
||||
cursor := firstSelectableIndex(entries)
|
||||
offset := 0
|
||||
|
||||
for {
|
||||
w, h := screen.Size()
|
||||
selectedCount := countSelected(entries)
|
||||
|
||||
drawPicker(screen, entries, groups, cursor, offset, w, h, selectedCount, totalItems)
|
||||
screen.Show()
|
||||
|
||||
ev := screen.PollEvent()
|
||||
switch ev := ev.(type) {
|
||||
case *tcell.EventResize:
|
||||
screen.Sync()
|
||||
case *tcell.EventKey:
|
||||
switch action := keyAction(ev); action {
|
||||
case actionQuit:
|
||||
return nil, nil
|
||||
case actionConfirm:
|
||||
if countSelected(entries) > 0 {
|
||||
return collectSelected(entries), nil
|
||||
}
|
||||
case actionUp:
|
||||
if cursor > 0 {
|
||||
cursor--
|
||||
}
|
||||
case actionDown:
|
||||
if cursor < len(entries)-1 {
|
||||
cursor++
|
||||
}
|
||||
case actionTop:
|
||||
cursor = 0
|
||||
case actionBottom:
|
||||
if len(entries) == 0 {
|
||||
cursor = 0
|
||||
} else {
|
||||
cursor = len(entries) - 1
|
||||
}
|
||||
case actionPageUp:
|
||||
listHeight := h - headerLines - footerLines
|
||||
cursor -= listHeight
|
||||
if cursor < 0 {
|
||||
cursor = 0
|
||||
}
|
||||
case actionPageDown:
|
||||
listHeight := h - headerLines - footerLines
|
||||
cursor += listHeight
|
||||
if cursor >= len(entries) {
|
||||
cursor = len(entries) - 1
|
||||
}
|
||||
case actionToggle:
|
||||
toggleAtCursor(entries, cursor)
|
||||
case actionAll:
|
||||
setAll(entries, true)
|
||||
case actionNone:
|
||||
setAll(entries, false)
|
||||
}
|
||||
|
||||
// Keep the cursor visible
|
||||
listHeight := h - headerLines - footerLines
|
||||
if listHeight < 1 {
|
||||
listHeight = 1
|
||||
}
|
||||
if cursor < offset {
|
||||
offset = cursor
|
||||
}
|
||||
if cursor >= offset+listHeight {
|
||||
offset = cursor - listHeight + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- actions ----------------------------------------------------------------
|
||||
|
||||
type action int
|
||||
|
||||
const (
|
||||
actionNoop action = iota
|
||||
actionQuit
|
||||
actionConfirm
|
||||
actionUp
|
||||
actionDown
|
||||
actionTop
|
||||
actionBottom
|
||||
actionPageUp
|
||||
actionPageDown
|
||||
actionToggle
|
||||
actionAll
|
||||
actionNone
|
||||
)
|
||||
|
||||
func keyAction(ev *tcell.EventKey) action {
|
||||
switch ev.Key() {
|
||||
case tcell.KeyEscape, tcell.KeyCtrlC:
|
||||
return actionQuit
|
||||
case tcell.KeyEnter:
|
||||
return actionConfirm
|
||||
case tcell.KeyUp:
|
||||
return actionUp
|
||||
case tcell.KeyDown:
|
||||
return actionDown
|
||||
case tcell.KeyHome:
|
||||
return actionTop
|
||||
case tcell.KeyEnd:
|
||||
return actionBottom
|
||||
case tcell.KeyPgUp:
|
||||
return actionPageUp
|
||||
case tcell.KeyPgDn:
|
||||
return actionPageDown
|
||||
case tcell.KeyRune:
|
||||
switch ev.Rune() {
|
||||
case 'q':
|
||||
return actionQuit
|
||||
case ' ':
|
||||
return actionToggle
|
||||
case 'j':
|
||||
return actionDown
|
||||
case 'k':
|
||||
return actionUp
|
||||
case 'g':
|
||||
return actionTop
|
||||
case 'G':
|
||||
return actionBottom
|
||||
case 'a':
|
||||
return actionAll
|
||||
case 'n':
|
||||
return actionNone
|
||||
}
|
||||
}
|
||||
return actionNoop
|
||||
}
|
||||
|
||||
// --- data helpers ------------------------------------------------------------
|
||||
|
||||
func buildEntries(groups []PickerGroup) []entry {
|
||||
var entries []entry
|
||||
flat := 0
|
||||
for gi, g := range groups {
|
||||
entries = append(entries, entry{
|
||||
label: g.Label,
|
||||
isHeader: true,
|
||||
groupIdx: gi,
|
||||
flatIdx: -1,
|
||||
})
|
||||
for _, item := range g.Items {
|
||||
entries = append(entries, entry{
|
||||
label: item,
|
||||
isHeader: false,
|
||||
selected: false,
|
||||
groupIdx: gi,
|
||||
flatIdx: flat,
|
||||
})
|
||||
flat++
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func firstSelectableIndex(entries []entry) int {
|
||||
for i, e := range entries {
|
||||
if !e.isHeader {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func countItems(entries []entry) int {
|
||||
n := 0
|
||||
for _, e := range entries {
|
||||
if !e.isHeader {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func countSelected(entries []entry) int {
|
||||
n := 0
|
||||
for _, e := range entries {
|
||||
if !e.isHeader && e.selected {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func collectSelected(entries []entry) []int {
|
||||
var result []int
|
||||
for _, e := range entries {
|
||||
if !e.isHeader && e.selected {
|
||||
result = append(result, e.flatIdx)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func toggleAtCursor(entries []entry, cursor int) {
|
||||
if cursor < 0 || cursor >= len(entries) {
|
||||
return
|
||||
}
|
||||
e := entries[cursor]
|
||||
if e.isHeader {
|
||||
// Toggle entire group: if all selected -> deselect all, else select all
|
||||
allSelected := true
|
||||
for _, e2 := range entries {
|
||||
if !e2.isHeader && e2.groupIdx == e.groupIdx && !e2.selected {
|
||||
allSelected = false
|
||||
break
|
||||
}
|
||||
}
|
||||
for i := range entries {
|
||||
if !entries[i].isHeader && entries[i].groupIdx == e.groupIdx {
|
||||
entries[i].selected = !allSelected
|
||||
}
|
||||
}
|
||||
} else {
|
||||
entries[cursor].selected = !entries[cursor].selected
|
||||
}
|
||||
}
|
||||
|
||||
func setAll(entries []entry, selected bool) {
|
||||
for i := range entries {
|
||||
if !entries[i].isHeader {
|
||||
entries[i].selected = selected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- drawing ----------------------------------------------------------------
|
||||
|
||||
const (
|
||||
headerLines = 2 // title + blank line
|
||||
footerLines = 2 // blank line + keybinds
|
||||
)
|
||||
|
||||
var (
|
||||
styleDefault = tcell.StyleDefault
|
||||
styleTitle = tcell.StyleDefault.Bold(true)
|
||||
styleGroup = tcell.StyleDefault.Bold(true).Foreground(tcell.ColorTeal)
|
||||
styleGroupCur = tcell.StyleDefault.Bold(true).Foreground(tcell.ColorTeal).Reverse(true)
|
||||
styleCheck = tcell.StyleDefault.Foreground(tcell.ColorGreen).Bold(true)
|
||||
styleUncheck = tcell.StyleDefault.Dim(true)
|
||||
styleItem = tcell.StyleDefault
|
||||
styleItemCur = tcell.StyleDefault.Bold(true).Underline(true)
|
||||
styleCheckCur = tcell.StyleDefault.Foreground(tcell.ColorGreen).Bold(true).Underline(true)
|
||||
styleUncheckCur = tcell.StyleDefault.Dim(true).Underline(true)
|
||||
styleFooter = tcell.StyleDefault.Dim(true)
|
||||
)
|
||||
|
||||
func drawPicker(
|
||||
screen tcell.Screen,
|
||||
entries []entry,
|
||||
groups []PickerGroup,
|
||||
cursor, offset, w, h, selectedCount, totalItems int,
|
||||
) {
|
||||
screen.Clear()
|
||||
|
||||
// Title
|
||||
title := fmt.Sprintf(" Select traces to open (%d/%d selected)", selectedCount, totalItems)
|
||||
drawLine(screen, 0, 0, w, title, styleTitle)
|
||||
|
||||
// List area
|
||||
listHeight := h - headerLines - footerLines
|
||||
if listHeight < 1 {
|
||||
listHeight = 1
|
||||
}
|
||||
|
||||
for i := 0; i < listHeight; i++ {
|
||||
ei := offset + i
|
||||
if ei >= len(entries) {
|
||||
break
|
||||
}
|
||||
y := headerLines + i
|
||||
renderEntry(screen, entries, groups, ei, cursor, w, y)
|
||||
}
|
||||
|
||||
// Scrollbar hint
|
||||
if len(entries) > listHeight {
|
||||
drawScrollbar(screen, w-1, headerLines, listHeight, offset, len(entries))
|
||||
}
|
||||
|
||||
// Footer
|
||||
footerY := h - 1
|
||||
footer := " ↑/↓ move space toggle a all n none enter open q/esc quit"
|
||||
drawLine(screen, 0, footerY, w, footer, styleFooter)
|
||||
}
|
||||
|
||||
func renderEntry(screen tcell.Screen, entries []entry, groups []PickerGroup, ei, cursor, w, y int) {
|
||||
e := entries[ei]
|
||||
isCursor := ei == cursor
|
||||
|
||||
if e.isHeader {
|
||||
groupSelected := 0
|
||||
groupTotal := 0
|
||||
for _, e2 := range entries {
|
||||
if !e2.isHeader && e2.groupIdx == e.groupIdx {
|
||||
groupTotal++
|
||||
if e2.selected {
|
||||
groupSelected++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
label := fmt.Sprintf(" %s (%d/%d)", e.label, groupSelected, groupTotal)
|
||||
style := styleGroup
|
||||
if isCursor {
|
||||
style = styleGroupCur
|
||||
}
|
||||
drawLine(screen, 0, y, w, label, style)
|
||||
return
|
||||
}
|
||||
|
||||
// Item row: " [x] label" or " > [x] label"
|
||||
prefix := " "
|
||||
if isCursor {
|
||||
prefix = " > "
|
||||
}
|
||||
|
||||
check := "[ ]"
|
||||
cStyle := styleUncheck
|
||||
iStyle := styleItem
|
||||
if isCursor {
|
||||
cStyle = styleUncheckCur
|
||||
iStyle = styleItemCur
|
||||
}
|
||||
if e.selected {
|
||||
check = "[x]"
|
||||
cStyle = styleCheck
|
||||
if isCursor {
|
||||
cStyle = styleCheckCur
|
||||
}
|
||||
}
|
||||
|
||||
x := drawStr(screen, 0, y, w, prefix, iStyle)
|
||||
x = drawStr(screen, x, y, w, check, cStyle)
|
||||
drawStr(screen, x, y, w, " "+e.label, iStyle)
|
||||
}
|
||||
|
||||
func drawScrollbar(screen tcell.Screen, x, top, height, offset, total int) {
|
||||
if total <= height || height < 1 {
|
||||
return
|
||||
}
|
||||
|
||||
thumbSize := max(1, height*height/total)
|
||||
thumbPos := top + offset*height/total
|
||||
|
||||
for y := top; y < top+height; y++ {
|
||||
ch := '│'
|
||||
style := styleDefault.Dim(true)
|
||||
if y >= thumbPos && y < thumbPos+thumbSize {
|
||||
ch = '┃'
|
||||
style = styleDefault
|
||||
}
|
||||
screen.SetContent(x, y, ch, nil, style)
|
||||
}
|
||||
}
|
||||
|
||||
// drawLine fills an entire row starting at x=startX, padding to width w.
|
||||
func drawLine(screen tcell.Screen, startX, y, w int, s string, style tcell.Style) {
|
||||
x := drawStr(screen, startX, y, w, s, style)
|
||||
// Clear the rest of the line
|
||||
for ; x < w; x++ {
|
||||
screen.SetContent(x, y, ' ', nil, style)
|
||||
}
|
||||
}
|
||||
|
||||
// drawStr writes a string at (x, y) up to maxX and returns the next x position.
|
||||
func drawStr(screen tcell.Screen, x, y, maxX int, s string, style tcell.Style) int {
|
||||
for _, ch := range s {
|
||||
if x >= maxX {
|
||||
break
|
||||
}
|
||||
screen.SetContent(x, y, ch, nil, style)
|
||||
x++
|
||||
}
|
||||
return x
|
||||
}
|
||||
22
web/lib/opal/src/icons/bifrost.tsx
Normal file
22
web/lib/opal/src/icons/bifrost.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import { cn } from "@opal/utils";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBifrost = ({ size, className, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 37 46"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className={cn(className, "text-[#33C19E] dark:text-white")}
|
||||
{...props}
|
||||
>
|
||||
<title>Bifrost</title>
|
||||
<path
|
||||
d="M27.6219 46H0V36.8H27.6219V46ZM36.8268 36.8H27.6219V27.6H36.8268V36.8ZM18.4146 27.6H9.2073V18.4H18.4146V27.6ZM36.8268 18.4H27.6219V9.2H36.8268V18.4ZM27.6219 9.2H0V0H27.6219V9.2Z"
|
||||
fill="currentColor"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgBifrost;
|
||||
@@ -24,6 +24,7 @@ export { default as SvgAzure } from "@opal/icons/azure";
|
||||
export { default as SvgBarChart } from "@opal/icons/bar-chart";
|
||||
export { default as SvgBarChartSmall } from "@opal/icons/bar-chart-small";
|
||||
export { default as SvgBell } from "@opal/icons/bell";
|
||||
export { default as SvgBifrost } from "@opal/icons/bifrost";
|
||||
export { default as SvgBlocks } from "@opal/icons/blocks";
|
||||
export { default as SvgBookOpen } from "@opal/icons/book-open";
|
||||
export { default as SvgBookmark } from "@opal/icons/bookmark";
|
||||
|
||||
24
web/package-lock.json
generated
24
web/package-lock.json
generated
@@ -7901,7 +7901,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/anymatch/node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=8.6"
|
||||
@@ -10701,7 +10703,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/handlebars": {
|
||||
"version": "4.7.8",
|
||||
"version": "4.7.9",
|
||||
"resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.9.tgz",
|
||||
"integrity": "sha512-4E71E0rpOaQuJR2A3xDZ+GM1HyWYv1clR58tC8emQNeQe3RH7MAzSbat+V0wG78LQBo6m6bzSG/L4pBuCsgnUQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
@@ -12555,7 +12559,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/jest-util/node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
@@ -13881,7 +13887,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/micromatch/node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=8.6"
|
||||
@@ -15001,7 +15009,9 @@
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/picomatch": {
|
||||
"version": "4.0.3",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz",
|
||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
@@ -15889,7 +15899,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/readdirp/node_modules/picomatch": {
|
||||
"version": "2.3.1",
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=8.6"
|
||||
|
||||
@@ -30,8 +30,11 @@ import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { Button } from "@opal/components";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgEdit, SvgInfo, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
|
||||
import { useBillingInformation } from "@/hooks/useBillingInformation";
|
||||
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTES.API_KEYS;
|
||||
@@ -44,6 +47,11 @@ function Main() {
|
||||
} = useSWR<APIKey[]>("/api/admin/api-key", errorHandlingFetcher);
|
||||
|
||||
const canCreateKeys = useCloudSubscription();
|
||||
const { data: billingData } = useBillingInformation();
|
||||
const isTrialing =
|
||||
billingData !== undefined &&
|
||||
hasActiveSubscription(billingData) &&
|
||||
billingData.status === BillingStatus.TRIALING;
|
||||
|
||||
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
|
||||
const [keyIsGenerating, setKeyIsGenerating] = useState(false);
|
||||
@@ -75,6 +83,16 @@ function Main() {
|
||||
|
||||
const introSection = (
|
||||
<div className="flex flex-col items-start gap-4">
|
||||
{isTrialing && (
|
||||
<Message
|
||||
static
|
||||
warning
|
||||
close={false}
|
||||
className="w-full"
|
||||
text="Upgrade to a paid plan to create API keys."
|
||||
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
|
||||
/>
|
||||
)}
|
||||
<Text as="p">
|
||||
API Keys allow you to access Onyx APIs programmatically.
|
||||
{canCreateKeys
|
||||
@@ -85,23 +103,9 @@ function Main() {
|
||||
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
|
||||
Create API Key
|
||||
</CreateButton>
|
||||
) : (
|
||||
<div className="flex flex-col gap-2 rounded-lg bg-background-tint-02 p-4">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Text as="p" text04>
|
||||
Upgrade to a paid plan to create API keys.
|
||||
</Text>
|
||||
<Button
|
||||
variant="none"
|
||||
prominence="tertiary"
|
||||
size="2xs"
|
||||
icon={SvgInfo}
|
||||
tooltip="API keys enable programmatic access to Onyx for service accounts and integrations. Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
|
||||
/>
|
||||
</div>
|
||||
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
|
||||
</div>
|
||||
)}
|
||||
) : isTrialing ? (
|
||||
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
|
||||
|
||||
387
web/src/app/admin/billing/page.test.tsx
Normal file
387
web/src/app/admin/billing/page.test.tsx
Normal file
@@ -0,0 +1,387 @@
|
||||
/**
|
||||
* Tests for BillingPage handleBillingReturn retry logic.
|
||||
*
|
||||
* The retry logic retries claimLicense up to 3 times with 2s backoff
|
||||
* when returning from a Stripe checkout session. This prevents the user
|
||||
* from getting stranded when the Stripe webhook fires concurrently with
|
||||
* the browser redirect and the license isn't ready yet.
|
||||
*/
|
||||
import React from "react";
|
||||
import { render, screen, waitFor } from "@tests/setup/test-utils";
|
||||
import { act } from "@testing-library/react";
|
||||
|
||||
// ---- Stable mock objects (must be named with mock* prefix for jest hoisting) ----
|
||||
// useRouter and useSearchParams must return the SAME reference each call, otherwise
|
||||
// React's useEffect sees them as changed and re-runs the effect on every render.
|
||||
const mockRouter = {
|
||||
replace: jest.fn() as jest.Mock,
|
||||
refresh: jest.fn() as jest.Mock,
|
||||
};
|
||||
const mockSearchParams = {
|
||||
get: jest.fn() as jest.Mock,
|
||||
};
|
||||
const mockClaimLicense = jest.fn() as jest.Mock;
|
||||
const mockRefreshBilling = jest.fn() as jest.Mock;
|
||||
const mockRefreshLicense = jest.fn() as jest.Mock;
|
||||
|
||||
// ---- Mocks ----
|
||||
|
||||
jest.mock("next/navigation", () => ({
|
||||
useRouter: () => mockRouter,
|
||||
useSearchParams: () => mockSearchParams,
|
||||
}));
|
||||
|
||||
jest.mock("@/layouts/settings-layouts", () => ({
|
||||
Root: ({ children }: { children: React.ReactNode }) => (
|
||||
<div data-testid="settings-root">{children}</div>
|
||||
),
|
||||
Header: () => <div data-testid="settings-header" />,
|
||||
Body: ({ children }: { children: React.ReactNode }) => (
|
||||
<div data-testid="settings-body">{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@/layouts/general-layouts", () => ({
|
||||
Section: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@opal/icons", () => ({
|
||||
SvgArrowUpCircle: () => <svg />,
|
||||
SvgWallet: () => <svg />,
|
||||
}));
|
||||
|
||||
jest.mock("./PlansView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="plans-view" />,
|
||||
}));
|
||||
jest.mock("./CheckoutView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="checkout-view" />,
|
||||
}));
|
||||
jest.mock("./BillingDetailsView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="billing-details-view" />,
|
||||
}));
|
||||
jest.mock("./LicenseActivationCard", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="license-activation-card" />,
|
||||
}));
|
||||
|
||||
jest.mock("@/refresh-components/messages/Message", () => ({
|
||||
__esModule: true,
|
||||
default: ({
|
||||
text,
|
||||
description,
|
||||
onClose,
|
||||
}: {
|
||||
text: string;
|
||||
description?: string;
|
||||
onClose?: () => void;
|
||||
}) => (
|
||||
<div data-testid="activating-banner">
|
||||
<span data-testid="activating-banner-text">{text}</span>
|
||||
{description && (
|
||||
<span data-testid="activating-banner-description">{description}</span>
|
||||
)}
|
||||
{onClose && (
|
||||
<button data-testid="activating-banner-close" onClick={onClose}>
|
||||
Close
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@/lib/billing", () => ({
|
||||
useBillingInformation: jest.fn(),
|
||||
useLicense: jest.fn(),
|
||||
hasActiveSubscription: jest.fn().mockReturnValue(false),
|
||||
claimLicense: (...args: unknown[]) => mockClaimLicense(...args),
|
||||
}));
|
||||
|
||||
jest.mock("@/lib/constants", () => ({
|
||||
NEXT_PUBLIC_CLOUD_ENABLED: false,
|
||||
}));
|
||||
|
||||
// ---- Import after mocks ----
|
||||
import BillingPage from "./page";
|
||||
import { useBillingInformation, useLicense } from "@/lib/billing";
|
||||
|
||||
// ---- Test helpers ----
|
||||
|
||||
function setupHooks() {
|
||||
(useBillingInformation as jest.Mock).mockReturnValue({
|
||||
data: null,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
refresh: mockRefreshBilling,
|
||||
});
|
||||
(useLicense as jest.Mock).mockReturnValue({
|
||||
data: null,
|
||||
isLoading: false,
|
||||
refresh: mockRefreshLicense,
|
||||
});
|
||||
}
|
||||
|
||||
// ---- Tests ----
|
||||
|
||||
describe("BillingPage — handleBillingReturn retry logic", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
jest.useFakeTimers();
|
||||
setupHooks();
|
||||
// Default: no billing-return params
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
// Clear any activating state from prior tests
|
||||
sessionStorage.clear();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
test("calls claimLicense once and refreshes on first-attempt success", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_test_123" : null
|
||||
);
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith("cs_test_123");
|
||||
});
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
// URL cleaned up after checkout return
|
||||
expect(mockRouter.replace).toHaveBeenCalledWith("/admin/billing", {
|
||||
scroll: false,
|
||||
});
|
||||
});
|
||||
|
||||
test("retries after first failure and succeeds on second attempt", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_retry_test" : null
|
||||
);
|
||||
mockClaimLicense
|
||||
.mockRejectedValueOnce(new Error("License not ready yet"))
|
||||
.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
// On eventual success, router and billing should be refreshed
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("retries all 3 times then navigates to details even on total failure", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_all_fail" : null
|
||||
);
|
||||
// All 3 attempts fail
|
||||
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
|
||||
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
// User stays on plans view with the activating banner
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("plans-view")).toBeInTheDocument();
|
||||
});
|
||||
// refreshBilling still fires so billing state is up to date
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
// Failure is logged
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("Failed to sync license after billing return"),
|
||||
expect.any(Error)
|
||||
);
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("calls claimLicense without session_id on portal_return", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "portal_return" ? "true" : null
|
||||
);
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
|
||||
// No session_id for portal returns — called with undefined
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
|
||||
});
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("does not call claimLicense when no billing-return params present", async () => {
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(mockClaimLicense).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("shows activating banner and sets sessionStorage on 3x retry failure", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_all_fail" : null
|
||||
);
|
||||
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
|
||||
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
});
|
||||
expect(screen.getByTestId("activating-banner-text")).toHaveTextContent(
|
||||
"Your license is still activating"
|
||||
);
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).not.toBeNull();
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("banner not rendered when no activating state", async () => {
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("banner shown on mount when sessionStorage key is set and not expired", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush React effects — banner is visible from lazy state init, no timer advancement needed
|
||||
await act(async () => {});
|
||||
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("banner not shown on mount when sessionStorage key is expired", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() - 1000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
test("poll calls claimLicense after 15s and clears banner on success", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
// Poll attempt succeeds
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush effects — banner visible from lazy state init
|
||||
await act(async () => {});
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
|
||||
// Advance past one poll interval (15s)
|
||||
await act(async () => {
|
||||
await jest.advanceTimersByTimeAsync(15_000);
|
||||
});
|
||||
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
expect(mockRefreshLicense).toHaveBeenCalled();
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("close button removes banner and clears sessionStorage", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush effects — banner visible from lazy state init
|
||||
await act(async () => {});
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
|
||||
const closeButton = screen.getByTestId("activating-banner-close");
|
||||
await act(async () => {
|
||||
closeButton.click();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
});
|
||||
});
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
} from "@/lib/billing";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
|
||||
import PlansView from "./PlansView";
|
||||
import CheckoutView from "./CheckoutView";
|
||||
@@ -24,6 +25,9 @@ import BillingDetailsView from "./BillingDetailsView";
|
||||
import LicenseActivationCard from "./LicenseActivationCard";
|
||||
import "./billing.css";
|
||||
|
||||
// sessionStorage key: value is a unix-ms expiry timestamp
|
||||
const BILLING_ACTIVATING_KEY = "billing_license_activating_until";
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Types
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -105,6 +109,7 @@ export default function BillingPage() {
|
||||
const [transitionType, setTransitionType] = useState<
|
||||
"expand" | "collapse" | "fade"
|
||||
>("fade");
|
||||
const [isActivating, setIsActivating] = useState<boolean>(false);
|
||||
|
||||
const {
|
||||
data: billingData,
|
||||
@@ -155,6 +160,17 @@ export default function BillingPage() {
|
||||
view,
|
||||
]);
|
||||
|
||||
// Read activating state from sessionStorage after mount (avoids SSR hydration mismatch)
|
||||
useEffect(() => {
|
||||
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
|
||||
if (!raw) return;
|
||||
if (Number(raw) > Date.now()) {
|
||||
setIsActivating(true);
|
||||
} else {
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Show license activation card when there's a Stripe error
|
||||
useEffect(() => {
|
||||
if (hasStripeError && !showLicenseActivationInput) {
|
||||
@@ -172,24 +188,96 @@ export default function BillingPage() {
|
||||
|
||||
router.replace("/admin/billing", { scroll: false });
|
||||
|
||||
let cancelled = false;
|
||||
|
||||
const handleBillingReturn = async () => {
|
||||
if (!NEXT_PUBLIC_CLOUD_ENABLED) {
|
||||
try {
|
||||
// After checkout, exchange session_id for license; after portal, re-sync license
|
||||
await claimLicense(sessionId ?? undefined);
|
||||
refreshLicense();
|
||||
// Refresh the page to update settings (including ee_features_enabled)
|
||||
router.refresh();
|
||||
// Navigate to billing details now that the license is active
|
||||
changeView("details");
|
||||
} catch (error) {
|
||||
console.error("Failed to sync license after billing return:", error);
|
||||
// Retry up to 3 times with 2s backoff. The license may not be available
|
||||
// immediately if the Stripe webhook hasn't finished processing yet
|
||||
// (redirect and webhook fire nearly simultaneously).
|
||||
let lastError: Error | null = null;
|
||||
for (let attempt = 0; attempt < 3; attempt++) {
|
||||
if (cancelled) return;
|
||||
try {
|
||||
// After checkout, exchange session_id for license; after portal, re-sync license
|
||||
await claimLicense(sessionId ?? undefined);
|
||||
if (cancelled) return;
|
||||
refreshLicense();
|
||||
// Refresh the page to update settings (including ee_features_enabled)
|
||||
router.refresh();
|
||||
// Navigate to billing details now that the license is active
|
||||
changeView("details");
|
||||
lastError = null;
|
||||
break;
|
||||
} catch (err) {
|
||||
lastError = err instanceof Error ? err : new Error("Unknown error");
|
||||
if (attempt < 2) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cancelled) return;
|
||||
if (lastError) {
|
||||
console.error(
|
||||
"Failed to sync license after billing return:",
|
||||
lastError
|
||||
);
|
||||
// Show an activating banner on the plans view and keep retrying in the background.
|
||||
sessionStorage.setItem(
|
||||
BILLING_ACTIVATING_KEY,
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
setIsActivating(true);
|
||||
changeView("plans");
|
||||
}
|
||||
}
|
||||
refreshBilling();
|
||||
if (!cancelled) refreshBilling();
|
||||
};
|
||||
handleBillingReturn();
|
||||
}, [searchParams, router, refreshBilling, refreshLicense]);
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
// changeView intentionally omitted: it only calls stable state setters and the
|
||||
// effect runs at most once (when session_id/portal_return params are present).
|
||||
}, [searchParams, router, refreshBilling, refreshLicense]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
// Poll every 15s while activating, up to 2 minutes, to detect when the license arrives.
|
||||
useEffect(() => {
|
||||
if (!isActivating) return;
|
||||
|
||||
let requestInFlight = false;
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
if (requestInFlight) return;
|
||||
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
|
||||
if (!raw || Number(raw) <= Date.now()) {
|
||||
// Expired — stop immediately without waiting for React cleanup
|
||||
clearInterval(intervalId);
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
return;
|
||||
}
|
||||
requestInFlight = true;
|
||||
try {
|
||||
await claimLicense(undefined);
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
refreshLicense();
|
||||
refreshBilling();
|
||||
router.refresh();
|
||||
changeView("details");
|
||||
} catch (err) {
|
||||
// License not ready yet — keep polling. Log so unexpected failures
|
||||
// (network errors, 500s) are distinguishable from expected 404s.
|
||||
console.debug("License activation poll: will retry", err);
|
||||
} finally {
|
||||
requestInFlight = false;
|
||||
}
|
||||
}, 15_000);
|
||||
|
||||
return () => clearInterval(intervalId);
|
||||
}, [isActivating]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
const handleRefresh = async () => {
|
||||
await Promise.all([
|
||||
@@ -386,6 +474,22 @@ export default function BillingPage() {
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<div className="flex flex-col items-center gap-6">
|
||||
{isActivating && (
|
||||
<Message
|
||||
static
|
||||
warning
|
||||
large
|
||||
text="Your license is still activating"
|
||||
description="Your license is being processed. You'll be taken to billing details automatically once confirmed."
|
||||
icon
|
||||
close
|
||||
onClose={() => {
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
}}
|
||||
className="w-full"
|
||||
/>
|
||||
)}
|
||||
{renderContent()}
|
||||
{renderFooter()}
|
||||
</div>
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useMemo } from "react";
|
||||
import { useState, useMemo, useEffect } from "react";
|
||||
import useSWR from "swr";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
|
||||
import {
|
||||
@@ -17,9 +18,16 @@ import {
|
||||
ImageGenerationConfigView,
|
||||
setDefaultImageGenerationConfig,
|
||||
unsetDefaultImageGenerationConfig,
|
||||
deleteImageGenerationConfig,
|
||||
} from "@/lib/configuration/imageConfigurationService";
|
||||
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
export default function ImageGenerationContent() {
|
||||
const {
|
||||
@@ -47,6 +55,11 @@ export default function ImageGenerationContent() {
|
||||
);
|
||||
const [editConfig, setEditConfig] =
|
||||
useState<ImageGenerationConfigView | null>(null);
|
||||
const [disconnectProvider, setDisconnectProvider] =
|
||||
useState<ImageProvider | null>(null);
|
||||
const [replacementProviderId, setReplacementProviderId] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
|
||||
const connectedProviderIds = useMemo(() => {
|
||||
return new Set(configs.map((c) => c.image_provider_id));
|
||||
@@ -115,6 +128,29 @@ export default function ImageGenerationContent() {
|
||||
modal.toggle(true);
|
||||
};
|
||||
|
||||
const handleDisconnect = async () => {
|
||||
if (!disconnectProvider) return;
|
||||
try {
|
||||
// If a replacement was selected (not "No Default"), activate it first
|
||||
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
|
||||
await setDefaultImageGenerationConfig(replacementProviderId);
|
||||
}
|
||||
|
||||
await deleteImageGenerationConfig(disconnectProvider.image_provider_id);
|
||||
toast.success(`${disconnectProvider.title} disconnected`);
|
||||
refetchConfigs();
|
||||
refetchProviders();
|
||||
} catch (error) {
|
||||
console.error("Failed to disconnect image generation provider:", error);
|
||||
toast.error(
|
||||
error instanceof Error ? error.message : "Failed to disconnect"
|
||||
);
|
||||
} finally {
|
||||
setDisconnectProvider(null);
|
||||
setReplacementProviderId(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleModalSuccess = () => {
|
||||
toast.success("Provider configured successfully");
|
||||
setEditConfig(null);
|
||||
@@ -130,6 +166,36 @@ export default function ImageGenerationContent() {
|
||||
);
|
||||
}
|
||||
|
||||
// Compute replacement options when disconnecting an active provider
|
||||
const isDisconnectingDefault =
|
||||
disconnectProvider &&
|
||||
defaultConfig?.image_provider_id === disconnectProvider.image_provider_id;
|
||||
|
||||
// Group connected replacement models by provider (excluding the model being disconnected)
|
||||
const replacementGroups = useMemo(() => {
|
||||
if (!disconnectProvider) return [];
|
||||
return IMAGE_PROVIDER_GROUPS.map((group) => ({
|
||||
...group,
|
||||
providers: group.providers.filter(
|
||||
(p) =>
|
||||
p.image_provider_id !== disconnectProvider.image_provider_id &&
|
||||
connectedProviderIds.has(p.image_provider_id)
|
||||
),
|
||||
})).filter((g) => g.providers.length > 0);
|
||||
}, [disconnectProvider, connectedProviderIds]);
|
||||
|
||||
const needsReplacement = !!isDisconnectingDefault;
|
||||
const hasReplacements = replacementGroups.length > 0;
|
||||
|
||||
// Auto-select first replacement when modal opens
|
||||
useEffect(() => {
|
||||
if (needsReplacement && !replacementProviderId && hasReplacements) {
|
||||
const firstGroup = replacementGroups[0];
|
||||
const firstModel = firstGroup?.providers[0];
|
||||
if (firstModel) setReplacementProviderId(firstModel.image_provider_id);
|
||||
}
|
||||
}, [disconnectProvider]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col gap-6">
|
||||
@@ -175,6 +241,11 @@ export default function ImageGenerationContent() {
|
||||
onSelect={() => handleSelect(provider)}
|
||||
onDeselect={() => handleDeselect(provider)}
|
||||
onEdit={() => handleEdit(provider)}
|
||||
onDisconnect={
|
||||
getStatus(provider) !== "disconnected"
|
||||
? () => setDisconnectProvider(provider)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
@@ -182,6 +253,105 @@ export default function ImageGenerationContent() {
|
||||
))}
|
||||
</div>
|
||||
|
||||
{disconnectProvider && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title={`Disconnect ${disconnectProvider.title}`}
|
||||
description="This will remove the stored credentials for this provider."
|
||||
onClose={() => {
|
||||
setDisconnectProvider(null);
|
||||
setReplacementProviderId(null);
|
||||
}}
|
||||
submit={
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={() => void handleDisconnect()}
|
||||
disabled={
|
||||
needsReplacement && hasReplacements && !replacementProviderId
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> is currently the default
|
||||
image generation model. Session history will be preserved.
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" text04>
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
value={replacementProviderId ?? undefined}
|
||||
onValueChange={(v) => setReplacementProviderId(v)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a replacement model" />
|
||||
<InputSelect.Content>
|
||||
{replacementGroups.map((group) => (
|
||||
<InputSelect.Group key={group.name}>
|
||||
<InputSelect.Label>{group.name}</InputSelect.Label>
|
||||
{group.providers.map((p) => (
|
||||
<InputSelect.Item
|
||||
key={p.image_provider_id}
|
||||
value={p.image_provider_id}
|
||||
icon={() => (
|
||||
<ProviderIcon
|
||||
provider={p.provider_name}
|
||||
size={16}
|
||||
/>
|
||||
)}
|
||||
>
|
||||
{p.title}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Group>
|
||||
))}
|
||||
<InputSelect.Separator />
|
||||
<InputSelect.Item
|
||||
value={NO_DEFAULT_VALUE}
|
||||
icon={SvgSlash}
|
||||
>
|
||||
<span>
|
||||
<b>No Default</b>
|
||||
<span className="text-text-03">
|
||||
{" "}
|
||||
(Disable Image Generation)
|
||||
</span>
|
||||
</span>
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> is currently the default
|
||||
image generation model.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Connect another provider to continue using image generation.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> models will no longer be used
|
||||
to generate images.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Session history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
)}
|
||||
|
||||
{activeProvider && (
|
||||
<modal.Provider>
|
||||
<ImageGenerationConnectionModal
|
||||
|
||||
@@ -23,6 +23,7 @@ import {
|
||||
BedrockModelResponse,
|
||||
LMStudioModelResponse,
|
||||
LiteLLMProxyModelResponse,
|
||||
BifrostModelResponse,
|
||||
ModelConfiguration,
|
||||
LLMProviderName,
|
||||
BedrockFetchParams,
|
||||
@@ -30,8 +31,9 @@ import {
|
||||
LMStudioFetchParams,
|
||||
OpenRouterFetchParams,
|
||||
LiteLLMProxyFetchParams,
|
||||
BifrostFetchParams,
|
||||
} from "@/interfaces/llm";
|
||||
import { SvgAws, SvgOpenrouter } from "@opal/icons";
|
||||
import { SvgAws, SvgBifrost, SvgOpenrouter } from "@opal/icons";
|
||||
|
||||
// Aggregator providers that host models from multiple vendors
|
||||
export const AGGREGATOR_PROVIDERS = new Set([
|
||||
@@ -41,6 +43,7 @@ export const AGGREGATOR_PROVIDERS = new Set([
|
||||
"ollama_chat",
|
||||
"lm_studio",
|
||||
"litellm_proxy",
|
||||
"bifrost",
|
||||
"vertex_ai",
|
||||
]);
|
||||
|
||||
@@ -78,6 +81,7 @@ export const getProviderIcon = (
|
||||
bedrock_converse: SvgAws,
|
||||
openrouter: SvgOpenrouter,
|
||||
litellm_proxy: LiteLLMIcon,
|
||||
bifrost: SvgBifrost,
|
||||
vertex_ai: GeminiIcon,
|
||||
};
|
||||
|
||||
@@ -263,8 +267,11 @@ export const fetchOpenRouterModels = async (
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse OpenRouter model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
@@ -319,8 +326,11 @@ export const fetchLMStudioModels = async (
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse LM Studio model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
@@ -343,6 +353,64 @@ export const fetchLMStudioModels = async (
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches Bifrost models directly without any form state dependencies.
|
||||
* Uses snake_case params to match API structure.
|
||||
*/
|
||||
export const fetchBifrostModels = async (
|
||||
params: BifrostFetchParams
|
||||
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
|
||||
const apiBase = params.api_base;
|
||||
if (!apiBase) {
|
||||
return { models: [], error: "API Base is required" };
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch("/api/admin/llm/bifrost/available-models", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_base: apiBase,
|
||||
api_key: params.api_key,
|
||||
provider_name: params.provider_name,
|
||||
}),
|
||||
signal: params.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = "Failed to fetch models";
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse Bifrost model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
|
||||
const data: BifrostModelResponse[] = await response.json();
|
||||
const models: ModelConfiguration[] = data.map((modelData) => ({
|
||||
name: modelData.name,
|
||||
display_name: modelData.display_name,
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: modelData.supports_reasoning,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : "Unknown error";
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches LiteLLM Proxy models directly without any form state dependencies.
|
||||
* Uses snake_case params to match API structure.
|
||||
@@ -456,6 +524,13 @@ export const fetchModels = async (
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
case LLMProviderName.BIFROST:
|
||||
return fetchBifrostModels({
|
||||
api_base: formValues.api_base,
|
||||
api_key: formValues.api_key,
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
default:
|
||||
return { models: [], error: `Unknown provider: ${providerName}` };
|
||||
}
|
||||
@@ -469,6 +544,7 @@ export function canProviderFetchModels(providerName?: string) {
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
case LLMProviderName.OPENROUTER:
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
case LLMProviderName.BIFROST:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -1,32 +1,25 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import { useMemo, useState, useReducer } from "react";
|
||||
import { useEffect, useMemo, useState, useReducer } from "react";
|
||||
import { InfoIcon } from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { Content } from "@opal/layouts";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgArrowRightCircle,
|
||||
SvgCheckSquare,
|
||||
SvgEdit,
|
||||
SvgGlobe,
|
||||
SvgOnyxLogo,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { SvgGlobe, SvgOnyxLogo, SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
|
||||
|
||||
const route = ADMIN_ROUTES.WEB_SEARCH;
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import {
|
||||
SEARCH_PROVIDERS_URL,
|
||||
SEARCH_PROVIDER_DETAILS,
|
||||
@@ -58,6 +51,10 @@ import {
|
||||
} from "@/app/admin/configuration/web-search/WebProviderModalReducer";
|
||||
import { connectProviderFlow } from "@/app/admin/configuration/web-search/connectProviderFlow";
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
const route = ADMIN_ROUTES.WEB_SEARCH;
|
||||
|
||||
interface WebSearchProviderView {
|
||||
id: number;
|
||||
name: string;
|
||||
@@ -76,27 +73,151 @@ interface WebContentProviderView {
|
||||
has_api_key: boolean;
|
||||
}
|
||||
|
||||
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
|
||||
isHovered: boolean;
|
||||
onMouseEnter: () => void;
|
||||
onMouseLeave: () => void;
|
||||
children: React.ReactNode;
|
||||
interface DisconnectTargetState {
|
||||
id: number;
|
||||
label: string;
|
||||
category: "search" | "content";
|
||||
providerType: string;
|
||||
}
|
||||
|
||||
function HoverIconButton({
|
||||
isHovered,
|
||||
onMouseEnter,
|
||||
onMouseLeave,
|
||||
children,
|
||||
...buttonProps
|
||||
}: HoverIconButtonProps) {
|
||||
function WebSearchDisconnectModal({
|
||||
disconnectTarget,
|
||||
searchProviders,
|
||||
contentProviders,
|
||||
replacementProviderId,
|
||||
onReplacementChange,
|
||||
onClose,
|
||||
onDisconnect,
|
||||
}: {
|
||||
disconnectTarget: DisconnectTargetState;
|
||||
searchProviders: WebSearchProviderView[];
|
||||
contentProviders: WebContentProviderView[];
|
||||
replacementProviderId: string | null;
|
||||
onReplacementChange: (id: string | null) => void;
|
||||
onClose: () => void;
|
||||
onDisconnect: () => void;
|
||||
}) {
|
||||
const isSearch = disconnectTarget.category === "search";
|
||||
|
||||
// Determine if the target is currently the active/selected provider
|
||||
const isActive = isSearch
|
||||
? searchProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
|
||||
false
|
||||
: contentProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
|
||||
false;
|
||||
|
||||
// Find other configured providers as replacements
|
||||
const replacementOptions = isSearch
|
||||
? searchProviders.filter(
|
||||
(p) => p.id !== disconnectTarget.id && p.id > 0 && p.has_api_key
|
||||
)
|
||||
: contentProviders.filter(
|
||||
(p) =>
|
||||
p.id !== disconnectTarget.id &&
|
||||
p.provider_type !== "onyx_web_crawler" &&
|
||||
p.id > 0 &&
|
||||
p.has_api_key
|
||||
);
|
||||
|
||||
const needsReplacement = isActive;
|
||||
const hasReplacements = replacementOptions.length > 0;
|
||||
|
||||
const getLabel = (p: { name: string; provider_type: string }) => {
|
||||
if (isSearch) {
|
||||
const details =
|
||||
SEARCH_PROVIDER_DETAILS[p.provider_type as WebSearchProviderType];
|
||||
return details?.label ?? p.name ?? p.provider_type;
|
||||
}
|
||||
const details = CONTENT_PROVIDER_DETAILS[p.provider_type];
|
||||
return details?.label ?? p.name ?? p.provider_type;
|
||||
};
|
||||
|
||||
const categoryLabel = isSearch ? "search engine" : "web crawler";
|
||||
const featureLabel = isSearch ? "web search" : "web crawling";
|
||||
const disableLabel = isSearch ? "Disable Web Search" : "Disable Web Crawling";
|
||||
|
||||
// Auto-select first replacement when modal opens
|
||||
useEffect(() => {
|
||||
if (needsReplacement && hasReplacements && !replacementProviderId) {
|
||||
const first = replacementOptions[0];
|
||||
if (first) onReplacementChange(String(first.id));
|
||||
}
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
{/* TODO(@raunakab): migrate to opal Button once HoverIconButtonProps typing is resolved */}
|
||||
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
|
||||
{children}
|
||||
</Button>
|
||||
</div>
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title={`Disconnect ${disconnectTarget.label}`}
|
||||
description="This will remove the stored credentials for this provider."
|
||||
onClose={onClose}
|
||||
submit={
|
||||
<OpalButton
|
||||
variant="danger"
|
||||
onClick={onDisconnect}
|
||||
disabled={
|
||||
needsReplacement && hasReplacements && !replacementProviderId
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</OpalButton>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.label}</b> is currently the active{" "}
|
||||
{categoryLabel}. Search history will be preserved.
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" secondaryBody text03>
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
value={replacementProviderId ?? undefined}
|
||||
onValueChange={(v) => onReplacementChange(v)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a replacement provider" />
|
||||
<InputSelect.Content>
|
||||
{replacementOptions.map((p) => (
|
||||
<InputSelect.Item key={p.id} value={String(p.id)}>
|
||||
{getLabel(p)}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
<InputSelect.Separator />
|
||||
<InputSelect.Item value={NO_DEFAULT_VALUE} icon={SvgSlash}>
|
||||
<span>
|
||||
<b>No Default</b>
|
||||
<span className="text-text-03"> ({disableLabel})</span>
|
||||
</span>
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.label}</b> is currently the active{" "}
|
||||
{categoryLabel}.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Connect another provider to continue using {featureLabel}.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
{isSearch ? "Web search" : "Web crawling"} will no longer be routed
|
||||
through <b>{disconnectTarget.label}</b>.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Search history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -105,6 +226,11 @@ export default function Page() {
|
||||
WebProviderModalReducer,
|
||||
initialWebProviderModalState
|
||||
);
|
||||
const [disconnectTarget, setDisconnectTarget] =
|
||||
useState<DisconnectTargetState | null>(null);
|
||||
const [replacementProviderId, setReplacementProviderId] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
const [contentModal, dispatchContentModal] = useReducer(
|
||||
WebProviderModalReducer,
|
||||
initialWebProviderModalState
|
||||
@@ -113,8 +239,6 @@ export default function Page() {
|
||||
const [contentActivationError, setContentActivationError] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
|
||||
|
||||
const {
|
||||
data: searchProvidersData,
|
||||
error: searchProvidersError,
|
||||
@@ -833,6 +957,67 @@ export default function Page() {
|
||||
});
|
||||
};
|
||||
|
||||
const handleDisconnectProvider = async () => {
|
||||
if (!disconnectTarget) return;
|
||||
const { id, category } = disconnectTarget;
|
||||
|
||||
try {
|
||||
// If a replacement was selected (not "No Default"), activate it first
|
||||
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
|
||||
const repId = Number(replacementProviderId);
|
||||
const activateEndpoint =
|
||||
category === "search"
|
||||
? `/api/admin/web-search/search-providers/${repId}/activate`
|
||||
: `/api/admin/web-search/content-providers/${repId}/activate`;
|
||||
const activateResp = await fetch(activateEndpoint, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
if (!activateResp.ok) {
|
||||
const errorBody = await activateResp.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to activate replacement provider."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
`/api/admin/web-search/${category}-providers/${id}`,
|
||||
{ method: "DELETE" }
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch((parseErr) => {
|
||||
console.error("Failed to parse disconnect error response:", parseErr);
|
||||
return {};
|
||||
});
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to disconnect provider."
|
||||
);
|
||||
}
|
||||
|
||||
toast.success(`${disconnectTarget.label} disconnected`);
|
||||
await mutateSearchProviders();
|
||||
await mutateContentProviders();
|
||||
} catch (error) {
|
||||
console.error("Failed to disconnect web search provider:", error);
|
||||
const message =
|
||||
error instanceof Error ? error.message : "Unexpected error occurred.";
|
||||
if (category === "search") {
|
||||
setActivationError(message);
|
||||
} else {
|
||||
setContentActivationError(message);
|
||||
}
|
||||
} finally {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<SettingsLayouts.Root>
|
||||
@@ -894,149 +1079,79 @@ export default function Page() {
|
||||
provider
|
||||
);
|
||||
const isActive = provider?.is_active ?? false;
|
||||
const isHighlighted = isActive;
|
||||
const providerId = provider?.id;
|
||||
const canOpenModal =
|
||||
isBuiltInSearchProviderType(providerType);
|
||||
|
||||
const buttonState = (() => {
|
||||
if (!provider || !isConfigured) {
|
||||
return {
|
||||
label: "Connect",
|
||||
disabled: false,
|
||||
icon: "arrow" as const,
|
||||
onClick: canOpenModal
|
||||
const status: "disconnected" | "connected" | "selected" =
|
||||
!isConfigured
|
||||
? "disconnected"
|
||||
: isActive
|
||||
? "selected"
|
||||
: "connected";
|
||||
|
||||
return (
|
||||
<Select
|
||||
key={`${key}-${providerType}`}
|
||||
icon={() =>
|
||||
logoSrc ? (
|
||||
<Image
|
||||
src={logoSrc}
|
||||
alt={`${label} logo`}
|
||||
width={16}
|
||||
height={16}
|
||||
/>
|
||||
) : (
|
||||
<SvgGlobe size={16} />
|
||||
)
|
||||
}
|
||||
title={label}
|
||||
description={subtitle}
|
||||
status={status}
|
||||
onConnect={
|
||||
canOpenModal
|
||||
? () => {
|
||||
openSearchModal(providerType, provider);
|
||||
setActivationError(null);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
if (isActive) {
|
||||
return {
|
||||
label: "Current Default",
|
||||
disabled: false,
|
||||
icon: "check" as const,
|
||||
onClick: providerId
|
||||
: undefined
|
||||
}
|
||||
onSelect={
|
||||
providerId
|
||||
? () => {
|
||||
void handleActivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
onDeselect={
|
||||
providerId
|
||||
? () => {
|
||||
void handleDeactivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
disabled: false,
|
||||
icon: "arrow-circle" as const,
|
||||
onClick: providerId
|
||||
? () => {
|
||||
void handleActivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
})();
|
||||
|
||||
const buttonKey = `search-${key}-${providerType}`;
|
||||
const isButtonHovered = hoveredButtonKey === buttonKey;
|
||||
const isCardClickable =
|
||||
buttonState.icon === "arrow" &&
|
||||
typeof buttonState.onClick === "function" &&
|
||||
!buttonState.disabled;
|
||||
|
||||
const handleCardClick = () => {
|
||||
if (isCardClickable) {
|
||||
buttonState.onClick?.();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`${key}-${providerType}`}
|
||||
onClick={isCardClickable ? handleCardClick : undefined}
|
||||
className={cn(
|
||||
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
|
||||
isHighlighted
|
||||
? "border-action-link-05"
|
||||
: "border-border-01",
|
||||
isCardClickable &&
|
||||
"cursor-pointer hover:bg-background-tint-01 transition-colors"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-1 items-start gap-1 px-2 py-1">
|
||||
{renderLogo({
|
||||
logoSrc,
|
||||
alt: `${label} logo`,
|
||||
size: 16,
|
||||
isHighlighted,
|
||||
})}
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{isConfigured && (
|
||||
<OpalButton
|
||||
icon={SvgEdit}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
if (!canOpenModal) return;
|
||||
: undefined
|
||||
}
|
||||
onEdit={
|
||||
isConfigured && canOpenModal
|
||||
? () => {
|
||||
openSearchModal(
|
||||
providerType as WebSearchProviderType,
|
||||
provider
|
||||
);
|
||||
}}
|
||||
aria-label={`Edit ${label}`}
|
||||
/>
|
||||
)}
|
||||
{buttonState.icon === "check" ? (
|
||||
<HoverIconButton
|
||||
isHovered={isButtonHovered}
|
||||
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
|
||||
onMouseLeave={() => setHoveredButtonKey(null)}
|
||||
action={true}
|
||||
tertiary
|
||||
disabled={buttonState.disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
>
|
||||
{buttonState.label}
|
||||
</HoverIconButton>
|
||||
) : (
|
||||
<Disabled
|
||||
disabled={
|
||||
buttonState.disabled || !buttonState.onClick
|
||||
}
|
||||
>
|
||||
<OpalButton
|
||||
prominence="tertiary"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
rightIcon={
|
||||
buttonState.icon === "arrow"
|
||||
? SvgArrowExchange
|
||||
: buttonState.icon === "arrow-circle"
|
||||
? SvgArrowRightCircle
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{buttonState.label}
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
: undefined
|
||||
}
|
||||
onDisconnect={
|
||||
isConfigured && provider && provider.id > 0
|
||||
? () =>
|
||||
setDisconnectTarget({
|
||||
id: provider.id,
|
||||
label,
|
||||
category: "search",
|
||||
providerType,
|
||||
})
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
}
|
||||
)}
|
||||
@@ -1076,161 +1191,81 @@ export default function Page() {
|
||||
const isCurrentCrawler =
|
||||
provider.provider_type === currentContentProviderType;
|
||||
|
||||
const buttonState = (() => {
|
||||
if (!isConfigured) {
|
||||
return {
|
||||
label: "Connect",
|
||||
icon: "arrow" as const,
|
||||
disabled: false,
|
||||
onClick: () => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
setContentActivationError(null);
|
||||
},
|
||||
};
|
||||
}
|
||||
const status: "disconnected" | "connected" | "selected" =
|
||||
!isConfigured
|
||||
? "disconnected"
|
||||
: isCurrentCrawler
|
||||
? "selected"
|
||||
: "connected";
|
||||
|
||||
if (isCurrentCrawler) {
|
||||
return {
|
||||
label: "Current Crawler",
|
||||
icon: "check" as const,
|
||||
disabled: false,
|
||||
onClick: () => {
|
||||
void handleDeactivateContentProvider(
|
||||
providerId,
|
||||
provider.provider_type
|
||||
);
|
||||
},
|
||||
};
|
||||
}
|
||||
const canActivate =
|
||||
providerId > 0 ||
|
||||
provider.provider_type === "onyx_web_crawler" ||
|
||||
isConfigured;
|
||||
|
||||
const canActivate =
|
||||
providerId > 0 ||
|
||||
provider.provider_type === "onyx_web_crawler" ||
|
||||
isConfigured;
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
icon: "arrow-circle" as const,
|
||||
disabled: !canActivate,
|
||||
onClick: canActivate
|
||||
? () => {
|
||||
void handleActivateContentProvider(provider);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
})();
|
||||
|
||||
const contentButtonKey = `content-${provider.provider_type}-${provider.id}`;
|
||||
const isContentButtonHovered =
|
||||
hoveredButtonKey === contentButtonKey;
|
||||
const isContentCardClickable =
|
||||
buttonState.icon === "arrow" &&
|
||||
typeof buttonState.onClick === "function" &&
|
||||
!buttonState.disabled;
|
||||
|
||||
const handleContentCardClick = () => {
|
||||
if (isContentCardClickable) {
|
||||
buttonState.onClick?.();
|
||||
}
|
||||
};
|
||||
const contentLogoSrc =
|
||||
CONTENT_PROVIDER_DETAILS[provider.provider_type]?.logoSrc;
|
||||
|
||||
return (
|
||||
<div
|
||||
<Select
|
||||
key={`${provider.provider_type}-${provider.id}`}
|
||||
onClick={
|
||||
isContentCardClickable
|
||||
? handleContentCardClick
|
||||
icon={() =>
|
||||
contentLogoSrc ? (
|
||||
<Image
|
||||
src={contentLogoSrc}
|
||||
alt={`${label} logo`}
|
||||
width={16}
|
||||
height={16}
|
||||
/>
|
||||
) : provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo size={16} />
|
||||
) : (
|
||||
<SvgGlobe size={16} />
|
||||
)
|
||||
}
|
||||
title={label}
|
||||
description={subtitle}
|
||||
status={status}
|
||||
selectedLabel="Current Crawler"
|
||||
onConnect={() => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
setContentActivationError(null);
|
||||
}}
|
||||
onSelect={
|
||||
canActivate
|
||||
? () => {
|
||||
void handleActivateContentProvider(provider);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
className={cn(
|
||||
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
|
||||
isCurrentCrawler
|
||||
? "border-action-link-05"
|
||||
: "border-border-01",
|
||||
isContentCardClickable &&
|
||||
"cursor-pointer hover:bg-background-tint-01 transition-colors"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-1 items-start gap-1 px-2 py-1">
|
||||
{renderLogo({
|
||||
logoSrc:
|
||||
CONTENT_PROVIDER_DETAILS[provider.provider_type]
|
||||
?.logoSrc,
|
||||
alt: `${label} logo`,
|
||||
fallback:
|
||||
provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo size={16} />
|
||||
) : undefined,
|
||||
size: 16,
|
||||
isHighlighted: isCurrentCrawler,
|
||||
})}
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured && (
|
||||
<OpalButton
|
||||
icon={SvgEdit}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
openContentModal(
|
||||
provider.provider_type,
|
||||
provider
|
||||
);
|
||||
}}
|
||||
aria-label={`Edit ${label}`}
|
||||
/>
|
||||
)}
|
||||
{buttonState.icon === "check" ? (
|
||||
<HoverIconButton
|
||||
isHovered={isContentButtonHovered}
|
||||
onMouseEnter={() =>
|
||||
setHoveredButtonKey(contentButtonKey)
|
||||
onDeselect={() => {
|
||||
void handleDeactivateContentProvider(
|
||||
providerId,
|
||||
provider.provider_type
|
||||
);
|
||||
}}
|
||||
onEdit={
|
||||
provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured
|
||||
? () => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
}
|
||||
onMouseLeave={() => setHoveredButtonKey(null)}
|
||||
action={true}
|
||||
tertiary
|
||||
disabled={buttonState.disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
>
|
||||
{buttonState.label}
|
||||
</HoverIconButton>
|
||||
) : (
|
||||
<Disabled
|
||||
disabled={
|
||||
buttonState.disabled || !buttonState.onClick
|
||||
}
|
||||
>
|
||||
<OpalButton
|
||||
prominence="tertiary"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
rightIcon={
|
||||
buttonState.icon === "arrow"
|
||||
? SvgArrowExchange
|
||||
: buttonState.icon === "arrow-circle"
|
||||
? SvgArrowRightCircle
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{buttonState.label}
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
: undefined
|
||||
}
|
||||
onDisconnect={
|
||||
provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured &&
|
||||
provider.id > 0
|
||||
? () =>
|
||||
setDisconnectTarget({
|
||||
id: provider.id,
|
||||
label,
|
||||
category: "content",
|
||||
providerType: provider.provider_type,
|
||||
})
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
@@ -1238,6 +1273,21 @@ export default function Page() {
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
|
||||
{disconnectTarget && (
|
||||
<WebSearchDisconnectModal
|
||||
disconnectTarget={disconnectTarget}
|
||||
searchProviders={searchProviders}
|
||||
contentProviders={combinedContentProviders}
|
||||
replacementProviderId={replacementProviderId}
|
||||
onReplacementChange={setReplacementProviderId}
|
||||
onClose={() => {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}}
|
||||
onDisconnect={() => void handleDisconnectProvider()}
|
||||
/>
|
||||
)}
|
||||
|
||||
<WebProviderSetupModal
|
||||
isOpen={selectedProviderType !== null}
|
||||
onClose={() => {
|
||||
|
||||
@@ -19,6 +19,10 @@
|
||||
background-color: var(--background-neutral-00);
|
||||
border: 1px solid var(--status-error-05);
|
||||
}
|
||||
.input-error:focus:not(:active),
|
||||
.input-error:focus-within:not(:active) {
|
||||
box-shadow: inset 0px 0px 0px 2px var(--background-tint-04);
|
||||
}
|
||||
|
||||
.input-disabled {
|
||||
background-color: var(--background-neutral-03);
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
* and support various display sizes.
|
||||
*/
|
||||
import React from "react";
|
||||
import { SvgBifrost } from "@opal/icons";
|
||||
import { render } from "@tests/setup/test-utils";
|
||||
import { GithubIcon, GitbookIcon, ConfluenceIcon } from "./icons";
|
||||
|
||||
@@ -51,4 +52,15 @@ describe("Logo Icons", () => {
|
||||
render(<GithubIcon size={100} className="custom-class" />);
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
test("renders the Bifrost icon with theme-aware colors", () => {
|
||||
const { container } = render(
|
||||
<SvgBifrost size={32} className="custom text-red-500 dark:text-black" />
|
||||
);
|
||||
const icon = container.querySelector("svg");
|
||||
|
||||
expect(icon).toBeInTheDocument();
|
||||
expect(icon).toHaveClass("custom", "text-[#33C19E]", "dark:text-white");
|
||||
expect(icon).not.toHaveClass("text-red-500", "dark:text-black");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -13,6 +13,7 @@ export enum LLMProviderName {
|
||||
VERTEX_AI = "vertex_ai",
|
||||
BEDROCK = "bedrock",
|
||||
LITELLM_PROXY = "litellm_proxy",
|
||||
BIFROST = "bifrost",
|
||||
CUSTOM = "custom",
|
||||
}
|
||||
|
||||
@@ -165,6 +166,21 @@ export interface LiteLLMProxyModelResponse {
|
||||
model_name: string;
|
||||
}
|
||||
|
||||
export interface BifrostFetchParams {
|
||||
api_base?: string;
|
||||
api_key?: string;
|
||||
provider_name?: string;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface BifrostModelResponse {
|
||||
name: string;
|
||||
display_name: string;
|
||||
max_input_tokens: number | null;
|
||||
supports_image_input: boolean;
|
||||
supports_reasoning: boolean;
|
||||
}
|
||||
|
||||
export interface VertexAIFetchParams {
|
||||
model_configurations?: ModelConfiguration[];
|
||||
}
|
||||
@@ -182,5 +198,6 @@ export type FetchModelsParams =
|
||||
| OllamaFetchParams
|
||||
| OpenRouterFetchParams
|
||||
| LiteLLMProxyFetchParams
|
||||
| BifrostFetchParams
|
||||
| VertexAIFetchParams
|
||||
| LMStudioFetchParams;
|
||||
|
||||
@@ -37,6 +37,7 @@ export interface Settings {
|
||||
// User Knowledge settings
|
||||
user_knowledge_enabled?: boolean;
|
||||
user_file_max_upload_size_mb?: number | null;
|
||||
file_token_count_threshold_k?: number | null;
|
||||
|
||||
// Connector settings
|
||||
show_extra_connectors?: boolean;
|
||||
@@ -68,6 +69,12 @@ export interface Settings {
|
||||
|
||||
// Application version from the ONYX_VERSION env var on the server.
|
||||
version?: string | null;
|
||||
// Hard ceiling for user_file_max_upload_size_mb, derived from env var.
|
||||
max_allowed_upload_size_mb?: number;
|
||||
|
||||
// Factory defaults for the restore button.
|
||||
default_user_file_max_upload_size_mb?: number;
|
||||
default_file_token_count_threshold_k?: number;
|
||||
}
|
||||
|
||||
export enum NotificationType {
|
||||
|
||||
@@ -53,6 +53,12 @@ export async function fetchVoicesByType(
|
||||
return fetch(`/api/admin/voice/voices?provider_type=${providerType}`);
|
||||
}
|
||||
|
||||
export async function deleteVoiceProvider(
|
||||
providerId: number
|
||||
): Promise<Response> {
|
||||
return fetch(`${VOICE_PROVIDERS_URL}/${providerId}`, { method: "DELETE" });
|
||||
}
|
||||
|
||||
export async function fetchLLMProviders(): Promise<Response> {
|
||||
return fetch("/api/admin/llm/provider");
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import {
|
||||
SvgBifrost,
|
||||
SvgCpu,
|
||||
SvgOpenai,
|
||||
SvgClaude,
|
||||
@@ -26,6 +27,7 @@ const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
|
||||
[LLMProviderName.OLLAMA_CHAT]: SvgOllama,
|
||||
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
|
||||
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
|
||||
[LLMProviderName.BIFROST]: SvgBifrost,
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: SvgServer,
|
||||
@@ -42,6 +44,7 @@ const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OLLAMA_CHAT]: "Ollama",
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Custom Models",
|
||||
@@ -58,6 +61,7 @@ const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OLLAMA_CHAT]: "Ollama",
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Other providers or self-hosted",
|
||||
|
||||
@@ -85,8 +85,6 @@ function buildFileKey(file: File): string {
|
||||
return `${file.size}|${namePrefix}`;
|
||||
}
|
||||
|
||||
const DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = 50;
|
||||
|
||||
interface ProjectsContextType {
|
||||
projects: Project[];
|
||||
recentFiles: ProjectFile[];
|
||||
@@ -341,21 +339,20 @@ export function ProjectsProvider({ children }: ProjectsProviderProps) {
|
||||
onFailure?: (failedTempIds: string[]) => void
|
||||
): Promise<ProjectFile[]> => {
|
||||
const rawMax = settingsContext?.settings?.user_file_max_upload_size_mb;
|
||||
const maxUploadSizeMb =
|
||||
rawMax && rawMax > 0 ? rawMax : DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB;
|
||||
const maxUploadSizeBytes = maxUploadSizeMb * 1024 * 1024;
|
||||
|
||||
const oversizedFiles = files.filter(
|
||||
(file) => file.size > maxUploadSizeBytes
|
||||
);
|
||||
const validFiles = files.filter(
|
||||
(file) => file.size <= maxUploadSizeBytes
|
||||
);
|
||||
const oversizedFiles =
|
||||
rawMax && rawMax > 0
|
||||
? files.filter((file) => file.size > rawMax * 1024 * 1024)
|
||||
: [];
|
||||
const validFiles =
|
||||
rawMax && rawMax > 0
|
||||
? files.filter((file) => file.size <= rawMax * 1024 * 1024)
|
||||
: files;
|
||||
|
||||
if (oversizedFiles.length > 0) {
|
||||
const skippedNames = oversizedFiles.map((file) => file.name).join(", ");
|
||||
toast.warning(
|
||||
`Skipped ${oversizedFiles.length} oversized file(s) (>${maxUploadSizeMb} MB): ${skippedNames}`
|
||||
`Skipped ${oversizedFiles.length} oversized file(s) (>${rawMax} MB): ${skippedNames}`
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -457,6 +457,7 @@ const ModalHeader = React.forwardRef<HTMLDivElement, ModalHeaderProps>(
|
||||
<div
|
||||
tabIndex={-1}
|
||||
ref={closeButtonRef as React.RefObject<HTMLDivElement>}
|
||||
className="outline-none"
|
||||
>
|
||||
<DialogPrimitive.Close asChild>
|
||||
<Button
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
SvgArrowRightCircle,
|
||||
SvgCheckSquare,
|
||||
SvgSettings,
|
||||
SvgUnplug,
|
||||
} from "@opal/icons";
|
||||
|
||||
const containerClasses = {
|
||||
@@ -35,6 +36,7 @@ export interface SelectProps
|
||||
onSelect?: () => void;
|
||||
onDeselect?: () => void;
|
||||
onEdit?: () => void;
|
||||
onDisconnect?: () => void;
|
||||
|
||||
// Labels (customizable)
|
||||
connectLabel?: string;
|
||||
@@ -59,6 +61,7 @@ export default function Select({
|
||||
onSelect,
|
||||
onDeselect,
|
||||
onEdit,
|
||||
onDisconnect,
|
||||
connectLabel = "Connect",
|
||||
selectLabel = "Set as Default",
|
||||
selectedLabel = "Current Default",
|
||||
@@ -68,7 +71,7 @@ export default function Select({
|
||||
disabled,
|
||||
...rest
|
||||
}: SelectProps) {
|
||||
const sizeClass = medium ? "h-[3.75rem]" : "h-[4.25rem]";
|
||||
const sizeClass = medium ? "h-[3.75rem]" : "min-h-[3.75rem] max-h-[5.25rem]";
|
||||
const containerClass = containerClasses[status];
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
@@ -121,7 +124,7 @@ export default function Select({
|
||||
</div>
|
||||
|
||||
{/* Right section - Actions */}
|
||||
<div className="flex items-center justify-end gap-1">
|
||||
<div className="flex flex-col h-full items-end justify-between gap-1">
|
||||
{/* Disconnected: Show Connect button */}
|
||||
{isDisconnected && (
|
||||
<Disabled disabled={disabled || !onConnect}>
|
||||
@@ -149,18 +152,32 @@ export default function Select({
|
||||
{selectLabel}
|
||||
</SelectButton>
|
||||
</Disabled>
|
||||
{onEdit && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgSettings}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onEdit)}
|
||||
aria-label={`Edit ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
<div className="flex px-1 gap-1">
|
||||
{onDisconnect && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgUnplug}
|
||||
tooltip="Disconnect"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onDisconnect)}
|
||||
aria-label={`Disconnect ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
{onEdit && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgSettings}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onEdit)}
|
||||
aria-label={`Edit ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -177,18 +194,32 @@ export default function Select({
|
||||
{selectedLabel}
|
||||
</SelectButton>
|
||||
</Disabled>
|
||||
{onEdit && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgSettings}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onEdit)}
|
||||
aria-label={`Edit ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
<div className="flex px-1 gap-1">
|
||||
{onDisconnect && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgUnplug}
|
||||
tooltip="Disconnect"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onDisconnect)}
|
||||
aria-label={`Disconnect ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
{onEdit && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgSettings}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onEdit)}
|
||||
aria-label={`Edit ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -12,7 +12,7 @@ export interface FieldContextType {
|
||||
|
||||
export type FormFieldRootProps = React.HTMLAttributes<HTMLDivElement> & {
|
||||
name?: string;
|
||||
state: FormFieldState;
|
||||
state?: FormFieldState;
|
||||
required?: boolean;
|
||||
id?: string;
|
||||
};
|
||||
|
||||
@@ -25,6 +25,7 @@ import {
|
||||
SvgFold,
|
||||
SvgExternalLink,
|
||||
SvgAlertCircle,
|
||||
SvgRefreshCw,
|
||||
} from "@opal/icons";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import { Content } from "@opal/layouts";
|
||||
@@ -54,6 +55,7 @@ import * as ExpandableCard from "@/layouts/expandable-card-layouts";
|
||||
import * as ActionsLayouts from "@/layouts/actions-layouts";
|
||||
import { getActionIcon } from "@/lib/tools/mcpUtils";
|
||||
import { Disabled } from "@opal/core";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import useFilter from "@/hooks/useFilter";
|
||||
import { MCPServer } from "@/lib/tools/interfaces";
|
||||
@@ -81,6 +83,10 @@ interface ChatPreferencesFormValues {
|
||||
maximum_chat_retention_days: string;
|
||||
anonymous_user_enabled: boolean;
|
||||
disable_default_assistant: boolean;
|
||||
|
||||
// File limits
|
||||
user_file_max_upload_size_mb: string;
|
||||
file_token_count_threshold_k: string;
|
||||
}
|
||||
|
||||
interface MCPServerCardTool {
|
||||
@@ -185,6 +191,173 @@ function MCPServerCard({
|
||||
);
|
||||
}
|
||||
|
||||
type FileLimitFieldName =
|
||||
| "user_file_max_upload_size_mb"
|
||||
| "file_token_count_threshold_k";
|
||||
|
||||
interface NumericLimitFieldProps {
|
||||
name: FileLimitFieldName;
|
||||
defaultValue: string;
|
||||
saveSettings: (updates: Partial<Settings>) => Promise<void>;
|
||||
maxValue?: number;
|
||||
allowZero?: boolean;
|
||||
}
|
||||
|
||||
function NumericLimitField({
|
||||
name,
|
||||
defaultValue,
|
||||
saveSettings,
|
||||
maxValue,
|
||||
allowZero = false,
|
||||
}: NumericLimitFieldProps) {
|
||||
const { values, setFieldValue } =
|
||||
useFormikContext<ChatPreferencesFormValues>();
|
||||
const initialValue = useRef(values[name]);
|
||||
const restoringRef = useRef(false);
|
||||
const value = values[name];
|
||||
|
||||
const parsed = parseInt(value, 10);
|
||||
const isOverMax =
|
||||
maxValue !== undefined && !isNaN(parsed) && parsed > maxValue;
|
||||
|
||||
const handleRestore = () => {
|
||||
restoringRef.current = true;
|
||||
initialValue.current = defaultValue;
|
||||
void setFieldValue(name, defaultValue);
|
||||
void saveSettings({ [name]: parseInt(defaultValue, 10) });
|
||||
};
|
||||
|
||||
const handleBlur = () => {
|
||||
// The restore button triggers a blur — skip since handleRestore already saved.
|
||||
if (restoringRef.current) {
|
||||
restoringRef.current = false;
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed = parseInt(value, 10);
|
||||
const isValid = !isNaN(parsed) && (allowZero ? parsed >= 0 : parsed > 0);
|
||||
|
||||
// Revert invalid input (empty, NaN, negative).
|
||||
if (!isValid) {
|
||||
if (allowZero) {
|
||||
// Empty/invalid means "no limit" — persist 0 and clear the field.
|
||||
void setFieldValue(name, "");
|
||||
void saveSettings({ [name]: 0 });
|
||||
initialValue.current = "";
|
||||
} else {
|
||||
void setFieldValue(name, initialValue.current);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Block save when the value exceeds the hard ceiling.
|
||||
if (maxValue !== undefined && parsed > maxValue) {
|
||||
return;
|
||||
}
|
||||
|
||||
// For allowZero fields, 0 means "no limit" — clear the display
|
||||
// so the "No limit" placeholder is visible, but still persist 0.
|
||||
if (allowZero && parsed === 0) {
|
||||
void setFieldValue(name, "");
|
||||
if (initialValue.current !== "") {
|
||||
void saveSettings({ [name]: 0 });
|
||||
initialValue.current = "";
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const normalizedDisplay = String(parsed);
|
||||
|
||||
// Update the display to the canonical form (e.g. strip leading zeros).
|
||||
if (value !== normalizedDisplay) {
|
||||
void setFieldValue(name, normalizedDisplay);
|
||||
}
|
||||
|
||||
// Persist only when the value actually changed.
|
||||
if (normalizedDisplay !== initialValue.current) {
|
||||
void saveSettings({ [name]: parsed });
|
||||
initialValue.current = normalizedDisplay;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="group w-full">
|
||||
<InputTypeInField
|
||||
name={name}
|
||||
inputMode="numeric"
|
||||
showClearButton={false}
|
||||
pattern="[0-9]*"
|
||||
placeholder={allowZero ? "No limit" : `Default: ${defaultValue}`}
|
||||
variant={isOverMax ? "error" : undefined}
|
||||
rightSection={
|
||||
(value || "") !== defaultValue ? (
|
||||
<div className="opacity-0 group-hover:opacity-100 group-focus-within:opacity-100 transition-opacity">
|
||||
<IconButton
|
||||
icon={SvgRefreshCw}
|
||||
tooltip="Restore default"
|
||||
internal
|
||||
type="button"
|
||||
onClick={handleRestore}
|
||||
/>
|
||||
</div>
|
||||
) : undefined
|
||||
}
|
||||
onBlur={handleBlur}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface FileSizeLimitFieldsProps {
|
||||
saveSettings: (updates: Partial<Settings>) => Promise<void>;
|
||||
defaultUploadSizeMb: string;
|
||||
defaultTokenThresholdK: string;
|
||||
maxAllowedUploadSizeMb?: number;
|
||||
}
|
||||
|
||||
function FileSizeLimitFields({
|
||||
saveSettings,
|
||||
defaultUploadSizeMb,
|
||||
defaultTokenThresholdK,
|
||||
maxAllowedUploadSizeMb,
|
||||
}: FileSizeLimitFieldsProps) {
|
||||
return (
|
||||
<div className="flex gap-4 w-full items-start">
|
||||
<div className="flex-1">
|
||||
<InputLayouts.Vertical
|
||||
title="File Size Limit (MB)"
|
||||
subDescription={
|
||||
maxAllowedUploadSizeMb
|
||||
? `Max: ${maxAllowedUploadSizeMb} MB`
|
||||
: undefined
|
||||
}
|
||||
nonInteractive
|
||||
>
|
||||
<NumericLimitField
|
||||
name="user_file_max_upload_size_mb"
|
||||
defaultValue={defaultUploadSizeMb}
|
||||
saveSettings={saveSettings}
|
||||
maxValue={maxAllowedUploadSizeMb}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<InputLayouts.Vertical
|
||||
title="File Token Limit (thousand tokens)"
|
||||
nonInteractive
|
||||
>
|
||||
<NumericLimitField
|
||||
name="file_token_count_threshold_k"
|
||||
defaultValue={defaultTokenThresholdK}
|
||||
saveSettings={saveSettings}
|
||||
allowZero
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Inner form component that uses useFormikContext to access values
|
||||
* and create save handlers for settings fields.
|
||||
@@ -201,6 +374,7 @@ function ChatPreferencesForm() {
|
||||
// Tools availability
|
||||
const { tools: availableTools } = useAvailableTools();
|
||||
const vectorDbEnabled = useVectorDbEnabled();
|
||||
|
||||
const searchTool = availableTools.find(
|
||||
(t) => t.in_code_tool_id === SEARCH_TOOL_ID
|
||||
);
|
||||
@@ -723,6 +897,28 @@ function ChatPreferencesForm() {
|
||||
</InputLayouts.Horizontal>
|
||||
</Card>
|
||||
|
||||
<Card>
|
||||
<InputLayouts.Vertical
|
||||
title="File Attachment Size Limit"
|
||||
description="Files attached in chats and projects must fit within both limits to be accepted. Larger files increase latency, memory usage, and token costs."
|
||||
>
|
||||
<FileSizeLimitFields
|
||||
saveSettings={saveSettings}
|
||||
defaultUploadSizeMb={
|
||||
settings?.settings.default_user_file_max_upload_size_mb?.toString() ??
|
||||
"100"
|
||||
}
|
||||
defaultTokenThresholdK={
|
||||
settings?.settings.default_file_token_count_threshold_k?.toString() ??
|
||||
"200"
|
||||
}
|
||||
maxAllowedUploadSizeMb={
|
||||
settings?.settings.max_allowed_upload_size_mb
|
||||
}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</Card>
|
||||
|
||||
<Card>
|
||||
<InputLayouts.Horizontal
|
||||
title="Allow Anonymous Users"
|
||||
@@ -862,6 +1058,21 @@ export default function ChatPreferencesPage() {
|
||||
anonymous_user_enabled: settings.settings.anonymous_user_enabled ?? false,
|
||||
disable_default_assistant:
|
||||
settings.settings.disable_default_assistant ?? false,
|
||||
|
||||
// File limits — for upload size: 0/null means "use default";
|
||||
// for token threshold: null means "use default", 0 means "no limit".
|
||||
user_file_max_upload_size_mb:
|
||||
(settings.settings.user_file_max_upload_size_mb ?? 0) <= 0
|
||||
? settings.settings.default_user_file_max_upload_size_mb?.toString() ??
|
||||
"100"
|
||||
: settings.settings.user_file_max_upload_size_mb!.toString(),
|
||||
file_token_count_threshold_k:
|
||||
settings.settings.file_token_count_threshold_k == null
|
||||
? settings.settings.default_file_token_count_threshold_k?.toString() ??
|
||||
"200"
|
||||
: settings.settings.file_token_count_threshold_k === 0
|
||||
? ""
|
||||
: settings.settings.file_token_count_threshold_k.toString(),
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
559
web/src/refresh-pages/admin/HooksPage/HookFormModal.tsx
Normal file
559
web/src/refresh-pages/admin/HooksPage/HookFormModal.tsx
Normal file
@@ -0,0 +1,559 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import {
|
||||
SvgCheckCircle,
|
||||
SvgHookNodes,
|
||||
SvgLoader,
|
||||
SvgRevert,
|
||||
} from "@opal/icons";
|
||||
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
|
||||
import { FormField } from "@/refresh-components/form/FormField";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import {
|
||||
createHook,
|
||||
updateHook,
|
||||
HookAuthError,
|
||||
HookTimeoutError,
|
||||
HookConnectError,
|
||||
} from "@/refresh-pages/admin/HooksPage/svc";
|
||||
import type {
|
||||
HookFailStrategy,
|
||||
HookFormState,
|
||||
HookPointMeta,
|
||||
HookResponse,
|
||||
HookUpdateRequest,
|
||||
} from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface HookFormModalProps {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
/** When provided, the modal is in edit mode for this hook. */
|
||||
hook?: HookResponse;
|
||||
/** When provided (create mode), the hook point is pre-selected and locked. */
|
||||
spec?: HookPointMeta;
|
||||
onSuccess: (hook: HookResponse) => void;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function buildInitialState(
|
||||
hook: HookResponse | undefined,
|
||||
spec: HookPointMeta | undefined
|
||||
): HookFormState {
|
||||
if (hook) {
|
||||
return {
|
||||
name: hook.name,
|
||||
endpoint_url: hook.endpoint_url ?? "",
|
||||
api_key: "",
|
||||
fail_strategy: hook.fail_strategy,
|
||||
timeout_seconds: String(hook.timeout_seconds),
|
||||
};
|
||||
}
|
||||
return {
|
||||
name: "",
|
||||
endpoint_url: "",
|
||||
api_key: "",
|
||||
fail_strategy: spec?.default_fail_strategy ?? "hard",
|
||||
timeout_seconds: spec ? String(spec.default_timeout_seconds) : "30",
|
||||
};
|
||||
}
|
||||
|
||||
const SOFT_DESCRIPTION =
|
||||
"If the endpoint returns an error, Onyx logs it and continues the pipeline as normal, ignoring the hook result.";
|
||||
|
||||
const MAX_TIMEOUT_SECONDS = 600;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export default function HookFormModal({
|
||||
open,
|
||||
onOpenChange,
|
||||
hook,
|
||||
spec,
|
||||
onSuccess,
|
||||
}: HookFormModalProps) {
|
||||
const isEdit = !!hook;
|
||||
const [form, setForm] = useState<HookFormState>(() =>
|
||||
buildInitialState(hook, spec)
|
||||
);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
// Tracks whether the user explicitly cleared the API key field in edit mode.
|
||||
// - false + empty field → key unchanged (omitted from PATCH)
|
||||
// - true + empty field → key cleared (api_key: null sent to backend)
|
||||
// - false + non-empty → new key provided (new value sent to backend)
|
||||
const [apiKeyCleared, setApiKeyCleared] = useState(false);
|
||||
const [touched, setTouched] = useState({
|
||||
name: false,
|
||||
endpoint_url: false,
|
||||
api_key: false,
|
||||
});
|
||||
const [apiKeyServerError, setApiKeyServerError] = useState(false);
|
||||
const [endpointServerError, setEndpointServerError] = useState<string | null>(
|
||||
null
|
||||
);
|
||||
const [timeoutServerError, setTimeoutServerError] = useState(false);
|
||||
|
||||
function touch(key: keyof typeof touched) {
|
||||
setTouched((prev) => ({ ...prev, [key]: true }));
|
||||
}
|
||||
|
||||
function handleOpenChange(next: boolean) {
|
||||
if (!next) {
|
||||
if (isSubmitting) return;
|
||||
setTimeout(() => {
|
||||
setForm(buildInitialState(hook, spec));
|
||||
setIsConnected(false);
|
||||
setApiKeyCleared(false);
|
||||
setTouched({ name: false, endpoint_url: false, api_key: false });
|
||||
setApiKeyServerError(false);
|
||||
setEndpointServerError(null);
|
||||
setTimeoutServerError(false);
|
||||
}, 200);
|
||||
}
|
||||
onOpenChange(next);
|
||||
}
|
||||
|
||||
function set<K extends keyof HookFormState>(key: K, value: HookFormState[K]) {
|
||||
setForm((prev) => ({ ...prev, [key]: value }));
|
||||
}
|
||||
|
||||
const timeoutNum = parseFloat(form.timeout_seconds);
|
||||
const isTimeoutValid =
|
||||
!isNaN(timeoutNum) && timeoutNum > 0 && timeoutNum <= MAX_TIMEOUT_SECONDS;
|
||||
const isValid =
|
||||
form.name.trim().length > 0 &&
|
||||
form.endpoint_url.trim().length > 0 &&
|
||||
isTimeoutValid &&
|
||||
(isEdit || form.api_key.trim().length > 0);
|
||||
|
||||
const nameError = touched.name && !form.name.trim();
|
||||
const endpointEmptyError = touched.endpoint_url && !form.endpoint_url.trim();
|
||||
const endpointFieldError = endpointEmptyError
|
||||
? "Endpoint URL cannot be empty."
|
||||
: endpointServerError ?? undefined;
|
||||
const apiKeyEmptyError = !isEdit && touched.api_key && !form.api_key.trim();
|
||||
const apiKeyFieldError = apiKeyEmptyError
|
||||
? "API key cannot be empty."
|
||||
: apiKeyServerError
|
||||
? "Invalid API key."
|
||||
: undefined;
|
||||
|
||||
function handleTimeoutBlur() {
|
||||
if (!isTimeoutValid) {
|
||||
const fallback = hook?.timeout_seconds ?? spec?.default_timeout_seconds;
|
||||
if (fallback !== undefined) {
|
||||
set("timeout_seconds", String(fallback));
|
||||
if (timeoutServerError) setTimeoutServerError(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const hasChanges =
|
||||
isEdit && hook
|
||||
? form.name !== hook.name ||
|
||||
form.endpoint_url !== (hook.endpoint_url ?? "") ||
|
||||
form.fail_strategy !== hook.fail_strategy ||
|
||||
timeoutNum !== hook.timeout_seconds ||
|
||||
form.api_key.trim().length > 0 ||
|
||||
apiKeyCleared
|
||||
: true;
|
||||
|
||||
async function handleSubmit() {
|
||||
if (!isValid) return;
|
||||
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
let result: HookResponse;
|
||||
if (isEdit && hook) {
|
||||
const req: HookUpdateRequest = {};
|
||||
if (form.name !== hook.name) req.name = form.name;
|
||||
if (form.endpoint_url !== (hook.endpoint_url ?? ""))
|
||||
req.endpoint_url = form.endpoint_url;
|
||||
if (form.fail_strategy !== hook.fail_strategy)
|
||||
req.fail_strategy = form.fail_strategy;
|
||||
if (timeoutNum !== hook.timeout_seconds)
|
||||
req.timeout_seconds = timeoutNum;
|
||||
if (form.api_key.trim().length > 0) {
|
||||
req.api_key = form.api_key;
|
||||
} else if (apiKeyCleared) {
|
||||
req.api_key = null;
|
||||
}
|
||||
if (Object.keys(req).length === 0) {
|
||||
setIsSubmitting(false);
|
||||
handleOpenChange(false);
|
||||
return;
|
||||
}
|
||||
result = await updateHook(hook.id, req);
|
||||
} else {
|
||||
if (!spec) {
|
||||
toast.error("No hook point specified.");
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
result = await createHook({
|
||||
name: form.name,
|
||||
hook_point: spec.hook_point,
|
||||
endpoint_url: form.endpoint_url,
|
||||
...(form.api_key ? { api_key: form.api_key } : {}),
|
||||
fail_strategy: form.fail_strategy,
|
||||
timeout_seconds: timeoutNum,
|
||||
});
|
||||
}
|
||||
toast.success(isEdit ? "Hook updated." : "Hook created.");
|
||||
onSuccess(result);
|
||||
if (!isEdit) {
|
||||
setIsConnected(true);
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
}
|
||||
setIsSubmitting(false);
|
||||
handleOpenChange(false);
|
||||
} catch (err) {
|
||||
if (err instanceof HookAuthError) {
|
||||
setApiKeyServerError(true);
|
||||
} else if (err instanceof HookTimeoutError) {
|
||||
setTimeoutServerError(true);
|
||||
} else if (err instanceof HookConnectError) {
|
||||
setEndpointServerError(err.message || "Could not connect to endpoint.");
|
||||
} else {
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Something went wrong."
|
||||
);
|
||||
}
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
const hookPointDisplayName =
|
||||
spec?.display_name ?? spec?.hook_point ?? hook?.hook_point ?? "";
|
||||
const hookPointDescription = spec?.description;
|
||||
const docsUrl = spec?.docs_url;
|
||||
|
||||
const failStrategyDescription =
|
||||
form.fail_strategy === "soft"
|
||||
? SOFT_DESCRIPTION
|
||||
: spec?.fail_hard_description;
|
||||
|
||||
return (
|
||||
<Modal open={open} onOpenChange={handleOpenChange}>
|
||||
<Modal.Content width="md" height="fit">
|
||||
<Modal.Header
|
||||
icon={SvgHookNodes}
|
||||
title={isEdit ? "Manage Hook Extension" : "Set Up Hook Extension"}
|
||||
description={
|
||||
isEdit
|
||||
? undefined
|
||||
: "Connect an external API endpoint to extend the hook point."
|
||||
}
|
||||
onClose={() => handleOpenChange(false)}
|
||||
/>
|
||||
|
||||
<Modal.Body>
|
||||
{/* Hook point section header */}
|
||||
<ContentAction
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="fit"
|
||||
title={hookPointDisplayName}
|
||||
description={hookPointDescription}
|
||||
rightChildren={
|
||||
<Section
|
||||
flexDirection="column"
|
||||
alignItems="end"
|
||||
width="fit"
|
||||
height="fit"
|
||||
gap={0.25}
|
||||
>
|
||||
<div className="flex items-center gap-1">
|
||||
<SvgHookNodes
|
||||
style={{ width: "1rem", height: "1rem" }}
|
||||
className="text-text-03 shrink-0"
|
||||
/>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
Hook Point
|
||||
</Text>
|
||||
</div>
|
||||
{docsUrl && (
|
||||
<a
|
||||
href={docsUrl}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
Documentation
|
||||
</Text>
|
||||
</a>
|
||||
)}
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
|
||||
<FormField className="w-full" state={nameError ? "error" : "idle"}>
|
||||
<FormField.Label>Display Name</FormField.Label>
|
||||
<FormField.Control>
|
||||
<div className="[&_input::placeholder]:!font-main-ui-muted w-full">
|
||||
<InputTypeIn
|
||||
value={form.name}
|
||||
onChange={(e) => set("name", e.target.value)}
|
||||
onBlur={() => touch("name")}
|
||||
placeholder="Name your extension at this hook point"
|
||||
variant={
|
||||
isSubmitting ? "disabled" : nameError ? "error" : undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</FormField.Control>
|
||||
<FormField.Message
|
||||
messages={{ error: "Display name cannot be empty." }}
|
||||
/>
|
||||
</FormField>
|
||||
|
||||
<FormField className="w-full">
|
||||
<FormField.Label>Fail Strategy</FormField.Label>
|
||||
<FormField.Control>
|
||||
<InputSelect
|
||||
value={form.fail_strategy}
|
||||
onValueChange={(v) =>
|
||||
set("fail_strategy", v as HookFailStrategy)
|
||||
}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select strategy" />
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item value="soft">
|
||||
Log Error and Continue
|
||||
{spec?.default_fail_strategy === "soft" && (
|
||||
<>
|
||||
{" "}
|
||||
<Text color="text-03">(Default)</Text>
|
||||
</>
|
||||
)}
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="hard">
|
||||
Block Pipeline on Failure
|
||||
{spec?.default_fail_strategy === "hard" && (
|
||||
<>
|
||||
{" "}
|
||||
<Text color="text-03">(Default)</Text>
|
||||
</>
|
||||
)}
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</FormField.Control>
|
||||
<FormField.Description>
|
||||
{failStrategyDescription}
|
||||
</FormField.Description>
|
||||
</FormField>
|
||||
|
||||
<FormField
|
||||
className="w-full"
|
||||
state={timeoutServerError ? "error" : "idle"}
|
||||
>
|
||||
<FormField.Label>
|
||||
Timeout{" "}
|
||||
<Text font="main-ui-action" color="text-03">
|
||||
(seconds)
|
||||
</Text>
|
||||
</FormField.Label>
|
||||
<FormField.Control>
|
||||
<div className="[&_input]:!font-main-ui-mono [&_input::placeholder]:!font-main-ui-mono [&_input]:![appearance:textfield] [&_input::-webkit-outer-spin-button]:!appearance-none [&_input::-webkit-inner-spin-button]:!appearance-none w-full">
|
||||
<InputTypeIn
|
||||
type="number"
|
||||
value={form.timeout_seconds}
|
||||
onChange={(e) => {
|
||||
set("timeout_seconds", e.target.value);
|
||||
if (timeoutServerError) setTimeoutServerError(false);
|
||||
}}
|
||||
onBlur={handleTimeoutBlur}
|
||||
placeholder={
|
||||
spec ? String(spec.default_timeout_seconds) : undefined
|
||||
}
|
||||
variant={
|
||||
isSubmitting
|
||||
? "disabled"
|
||||
: timeoutServerError
|
||||
? "error"
|
||||
: undefined
|
||||
}
|
||||
showClearButton={false}
|
||||
rightSection={
|
||||
spec?.default_timeout_seconds !== undefined &&
|
||||
form.timeout_seconds !==
|
||||
String(spec.default_timeout_seconds) ? (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="xs"
|
||||
icon={SvgRevert}
|
||||
tooltip="Revert to Default"
|
||||
onClick={() =>
|
||||
set(
|
||||
"timeout_seconds",
|
||||
String(spec.default_timeout_seconds)
|
||||
)
|
||||
}
|
||||
disabled={isSubmitting}
|
||||
/>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</FormField.Control>
|
||||
{!timeoutServerError && (
|
||||
<FormField.Description>
|
||||
Maximum time Onyx will wait for the endpoint to respond before
|
||||
applying the fail strategy. Must be greater than 0 and at most{" "}
|
||||
{MAX_TIMEOUT_SECONDS} seconds.
|
||||
</FormField.Description>
|
||||
)}
|
||||
<FormField.Message
|
||||
messages={{
|
||||
error: "Connection timed out. Try increasing the timeout.",
|
||||
}}
|
||||
/>
|
||||
</FormField>
|
||||
|
||||
<FormField
|
||||
className="w-full"
|
||||
state={endpointFieldError ? "error" : "idle"}
|
||||
>
|
||||
<FormField.Label>External API Endpoint URL</FormField.Label>
|
||||
<FormField.Control>
|
||||
<div className="[&_input::placeholder]:!font-main-ui-muted w-full">
|
||||
<InputTypeIn
|
||||
value={form.endpoint_url}
|
||||
onChange={(e) => {
|
||||
set("endpoint_url", e.target.value);
|
||||
if (endpointServerError) setEndpointServerError(null);
|
||||
}}
|
||||
onBlur={() => touch("endpoint_url")}
|
||||
placeholder="https://your-api-endpoint.com"
|
||||
variant={
|
||||
isSubmitting
|
||||
? "disabled"
|
||||
: endpointFieldError
|
||||
? "error"
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</FormField.Control>
|
||||
{!endpointFieldError && (
|
||||
<FormField.Description>
|
||||
Only connect to servers you trust. You are responsible for
|
||||
actions taken and data shared with this connection.
|
||||
</FormField.Description>
|
||||
)}
|
||||
<FormField.Message messages={{ error: endpointFieldError }} />
|
||||
</FormField>
|
||||
|
||||
<FormField
|
||||
className="w-full"
|
||||
state={apiKeyFieldError ? "error" : "idle"}
|
||||
>
|
||||
<FormField.Label>API Key</FormField.Label>
|
||||
<FormField.Control>
|
||||
<PasswordInputTypeIn
|
||||
value={form.api_key}
|
||||
onChange={(e) => {
|
||||
set("api_key", e.target.value);
|
||||
if (apiKeyServerError) setApiKeyServerError(false);
|
||||
if (isEdit) {
|
||||
setApiKeyCleared(
|
||||
e.target.value === "" && !!hook?.api_key_masked
|
||||
);
|
||||
}
|
||||
}}
|
||||
onBlur={() => touch("api_key")}
|
||||
placeholder={
|
||||
isEdit
|
||||
? hook?.api_key_masked ?? "Leave blank to keep current key"
|
||||
: undefined
|
||||
}
|
||||
disabled={isSubmitting}
|
||||
error={!!apiKeyFieldError}
|
||||
/>
|
||||
</FormField.Control>
|
||||
{!apiKeyFieldError && (
|
||||
<FormField.Description>
|
||||
Onyx will use this key to authenticate with your API endpoint.
|
||||
</FormField.Description>
|
||||
)}
|
||||
<FormField.Message messages={{ error: apiKeyFieldError }} />
|
||||
</FormField>
|
||||
|
||||
{!isEdit && (isSubmitting || isConnected) && (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
alignItems="center"
|
||||
justifyContent="start"
|
||||
height="fit"
|
||||
gap={1}
|
||||
className="px-0.5"
|
||||
>
|
||||
<div className="p-0.5 shrink-0">
|
||||
{isConnected ? (
|
||||
<SvgCheckCircle
|
||||
size={16}
|
||||
className="text-status-success-05"
|
||||
/>
|
||||
) : (
|
||||
<SvgLoader size={16} className="animate-spin text-text-03" />
|
||||
)}
|
||||
</div>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
{isConnected ? "Connection valid." : "Verifying connection…"}
|
||||
</Text>
|
||||
</Section>
|
||||
)}
|
||||
</Modal.Body>
|
||||
|
||||
<Modal.Footer>
|
||||
<BasicModalFooter
|
||||
cancel={
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<Button
|
||||
prominence="secondary"
|
||||
onClick={() => handleOpenChange(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</Disabled>
|
||||
}
|
||||
submit={
|
||||
<Disabled disabled={isSubmitting || !isValid || !hasChanges}>
|
||||
<Button
|
||||
onClick={handleSubmit}
|
||||
icon={
|
||||
isSubmitting && !isEdit
|
||||
? () => <SvgLoader size={16} className="animate-spin" />
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{isEdit ? "Save Changes" : "Connect"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
}
|
||||
/>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -29,6 +29,14 @@ export interface HookResponse {
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
export interface HookFormState {
|
||||
name: string;
|
||||
endpoint_url: string;
|
||||
api_key: string;
|
||||
fail_strategy: HookFailStrategy;
|
||||
timeout_seconds: string;
|
||||
}
|
||||
|
||||
export interface HookCreateRequest {
|
||||
name: string;
|
||||
hook_point: HookPoint;
|
||||
|
||||
@@ -5,15 +5,27 @@ import {
|
||||
HookValidateResponse,
|
||||
} from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
async function parseErrorDetail(
|
||||
res: Response,
|
||||
fallback: string
|
||||
): Promise<string> {
|
||||
export class HookAuthError extends Error {}
|
||||
export class HookTimeoutError extends Error {}
|
||||
export class HookConnectError extends Error {}
|
||||
|
||||
async function parseError(res: Response, fallback: string): Promise<Error> {
|
||||
try {
|
||||
const body = await res.json();
|
||||
return body?.detail ?? fallback;
|
||||
if (body?.error_code === "CREDENTIAL_INVALID") {
|
||||
return new HookAuthError(body?.detail ?? "Invalid API key.");
|
||||
}
|
||||
if (body?.error_code === "GATEWAY_TIMEOUT") {
|
||||
return new HookTimeoutError(body?.detail ?? "Connection timed out.");
|
||||
}
|
||||
if (body?.error_code === "BAD_GATEWAY") {
|
||||
return new HookConnectError(
|
||||
body?.detail ?? "Could not connect to endpoint."
|
||||
);
|
||||
}
|
||||
return new Error(body?.detail ?? fallback);
|
||||
} catch {
|
||||
return fallback;
|
||||
return new Error(fallback);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +38,7 @@ export async function createHook(
|
||||
body: JSON.stringify(req),
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to create hook"));
|
||||
throw await parseError(res, "Failed to create hook");
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
@@ -41,7 +53,7 @@ export async function updateHook(
|
||||
body: JSON.stringify(req),
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to update hook"));
|
||||
throw await parseError(res, "Failed to update hook");
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
@@ -49,7 +61,7 @@ export async function updateHook(
|
||||
export async function deleteHook(id: number): Promise<void> {
|
||||
const res = await fetch(`/api/admin/hooks/${id}`, { method: "DELETE" });
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to delete hook"));
|
||||
throw await parseError(res, "Failed to delete hook");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +70,7 @@ export async function activateHook(id: number): Promise<HookResponse> {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to activate hook"));
|
||||
throw await parseError(res, "Failed to activate hook");
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
@@ -68,7 +80,7 @@ export async function deactivateHook(id: number): Promise<HookResponse> {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to deactivate hook"));
|
||||
throw await parseError(res, "Failed to deactivate hook");
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
@@ -78,7 +90,7 @@ export async function validateHook(id: number): Promise<HookValidateResponse> {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to validate hook"));
|
||||
throw await parseError(res, "Failed to validate hook");
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
|
||||
@@ -45,6 +45,7 @@ import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
const route = ADMIN_ROUTES.LLM_MODELS;
|
||||
@@ -65,6 +66,7 @@ const PROVIDER_DISPLAY_ORDER: string[] = [
|
||||
"ollama_chat",
|
||||
"openrouter",
|
||||
"lm_studio",
|
||||
"bifrost",
|
||||
];
|
||||
|
||||
const PROVIDER_MODAL_MAP: Record<
|
||||
@@ -138,6 +140,13 @@ const PROVIDER_MODAL_MAP: Record<
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
bifrost: (d, open, onOpenChange) => (
|
||||
<BifrostModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo, useState } from "react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import {
|
||||
AzureIcon,
|
||||
ElevenLabsIcon,
|
||||
@@ -19,12 +19,18 @@ import {
|
||||
import {
|
||||
activateVoiceProvider,
|
||||
deactivateVoiceProvider,
|
||||
deleteVoiceProvider,
|
||||
} from "@/lib/admin/voice/svc";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { SvgMicrophone } from "@opal/icons";
|
||||
import { SvgMicrophone, SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import VoiceProviderSetupModal from "@/app/admin/configuration/voice/VoiceProviderSetupModal";
|
||||
|
||||
interface ModelDetails {
|
||||
@@ -129,10 +135,152 @@ function getProviderIcon(
|
||||
|
||||
type ProviderMode = "stt" | "tts";
|
||||
|
||||
function getProviderLabel(providerType: string): string {
|
||||
switch (providerType) {
|
||||
case "openai":
|
||||
return "OpenAI";
|
||||
case "azure":
|
||||
return "Azure";
|
||||
case "elevenlabs":
|
||||
return "ElevenLabs";
|
||||
default:
|
||||
return providerType;
|
||||
}
|
||||
}
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
const route = ADMIN_ROUTES.VOICE;
|
||||
const pageDescription =
|
||||
"Configure speech-to-text and text-to-speech providers for voice input and spoken responses.";
|
||||
|
||||
interface VoiceDisconnectModalProps {
|
||||
disconnectTarget: {
|
||||
providerId: number;
|
||||
providerLabel: string;
|
||||
providerType: string;
|
||||
};
|
||||
providers: VoiceProviderView[];
|
||||
replacementProviderId: string | null;
|
||||
onReplacementChange: (id: string | null) => void;
|
||||
onClose: () => void;
|
||||
onDisconnect: () => void;
|
||||
}
|
||||
|
||||
function VoiceDisconnectModal({
|
||||
disconnectTarget,
|
||||
providers,
|
||||
replacementProviderId,
|
||||
onReplacementChange,
|
||||
onClose,
|
||||
onDisconnect,
|
||||
}: VoiceDisconnectModalProps) {
|
||||
const targetProvider = providers.find(
|
||||
(p) => p.id === disconnectTarget.providerId
|
||||
);
|
||||
const isActive =
|
||||
(targetProvider?.is_default_stt ?? false) ||
|
||||
(targetProvider?.is_default_tts ?? false);
|
||||
|
||||
// Find other configured providers that could serve as replacements
|
||||
const replacementOptions = providers.filter(
|
||||
(p) => p.id !== disconnectTarget.providerId && p.has_api_key
|
||||
);
|
||||
|
||||
const needsReplacement = isActive;
|
||||
const hasReplacements = replacementOptions.length > 0;
|
||||
|
||||
// Auto-select first replacement when modal opens
|
||||
useEffect(() => {
|
||||
if (needsReplacement && hasReplacements && !replacementProviderId) {
|
||||
const first = replacementOptions[0];
|
||||
if (first) onReplacementChange(String(first.id));
|
||||
}
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title={`Disconnect ${disconnectTarget.providerLabel}`}
|
||||
description="Voice models"
|
||||
onClose={onClose}
|
||||
submit={
|
||||
<OpalButton
|
||||
variant="danger"
|
||||
onClick={onDisconnect}
|
||||
disabled={
|
||||
needsReplacement && hasReplacements && !replacementProviderId
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</OpalButton>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.providerLabel}</b> models will no longer be
|
||||
used for speech-to-text or text-to-speech, and it will no longer
|
||||
be your default. Session history will be preserved.
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" text04>
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
value={replacementProviderId ?? undefined}
|
||||
onValueChange={(v) => onReplacementChange(v)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a replacement provider" />
|
||||
<InputSelect.Content>
|
||||
{replacementOptions.map((p) => (
|
||||
<InputSelect.Item
|
||||
key={p.id}
|
||||
value={String(p.id)}
|
||||
icon={getProviderIcon(p.provider_type)}
|
||||
>
|
||||
{getProviderLabel(p.provider_type)}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
<InputSelect.Separator />
|
||||
<InputSelect.Item value={NO_DEFAULT_VALUE} icon={SvgSlash}>
|
||||
<span>
|
||||
<b>No Default</b>
|
||||
<span className="text-text-03"> (Disable Voice)</span>
|
||||
</span>
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.providerLabel}</b> models will no longer be
|
||||
used for speech-to-text or text-to-speech, and it will no longer
|
||||
be your default.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Connect another provider to continue using voice.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.providerLabel}</b> models will no longer be
|
||||
available for voice.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Session history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
);
|
||||
}
|
||||
|
||||
export default function VoiceConfigurationPage() {
|
||||
const [modalOpen, setModalOpen] = useState(false);
|
||||
const [selectedProvider, setSelectedProvider] = useState<string | null>(null);
|
||||
@@ -146,6 +294,14 @@ export default function VoiceConfigurationPage() {
|
||||
const [ttsActivationError, setTTSActivationError] = useState<string | null>(
|
||||
null
|
||||
);
|
||||
const [disconnectTarget, setDisconnectTarget] = useState<{
|
||||
providerId: number;
|
||||
providerLabel: string;
|
||||
providerType: string;
|
||||
} | null>(null);
|
||||
const [replacementProviderId, setReplacementProviderId] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
|
||||
const { providers, error, isLoading, refresh: mutate } = useVoiceProviders();
|
||||
|
||||
@@ -237,6 +393,65 @@ export default function VoiceConfigurationPage() {
|
||||
handleModalClose();
|
||||
};
|
||||
|
||||
const handleDisconnect = async () => {
|
||||
if (!disconnectTarget) return;
|
||||
try {
|
||||
const targetProvider = providers.find(
|
||||
(p) => p.id === disconnectTarget.providerId
|
||||
);
|
||||
|
||||
// If a replacement was selected (not "No Default"), activate it for each
|
||||
// mode the disconnected provider was default for
|
||||
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
|
||||
const repId = Number(replacementProviderId);
|
||||
|
||||
if (targetProvider?.is_default_stt) {
|
||||
const resp = await activateVoiceProvider(repId, "stt");
|
||||
if (!resp.ok) {
|
||||
const errorBody = await resp.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to activate replacement STT provider."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (targetProvider?.is_default_tts) {
|
||||
const resp = await activateVoiceProvider(repId, "tts");
|
||||
if (!resp.ok) {
|
||||
const errorBody = await resp.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to activate replacement TTS provider."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const response = await deleteVoiceProvider(disconnectTarget.providerId);
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to disconnect provider."
|
||||
);
|
||||
}
|
||||
await mutate();
|
||||
toast.success(`${disconnectTarget.providerLabel} disconnected`);
|
||||
} catch (err) {
|
||||
console.error("Failed to disconnect voice provider:", err);
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Unexpected error occurred."
|
||||
);
|
||||
} finally {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}
|
||||
};
|
||||
|
||||
const isProviderConfigured = (provider?: VoiceProviderView): boolean => {
|
||||
return !!provider?.has_api_key;
|
||||
};
|
||||
@@ -289,6 +504,16 @@ export default function VoiceConfigurationPage() {
|
||||
onEdit={() => {
|
||||
if (provider) handleEdit(provider, mode, model.id);
|
||||
}}
|
||||
onDisconnect={
|
||||
status !== "disconnected" && provider
|
||||
? () =>
|
||||
setDisconnectTarget({
|
||||
providerId: provider.id,
|
||||
providerLabel: getProviderLabel(model.providerType),
|
||||
providerType: model.providerType,
|
||||
})
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
@@ -412,6 +637,20 @@ export default function VoiceConfigurationPage() {
|
||||
</div>
|
||||
</SettingsLayouts.Body>
|
||||
|
||||
{disconnectTarget && (
|
||||
<VoiceDisconnectModal
|
||||
disconnectTarget={disconnectTarget}
|
||||
providers={providers}
|
||||
replacementProviderId={replacementProviderId}
|
||||
onReplacementChange={setReplacementProviderId}
|
||||
onClose={() => {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}}
|
||||
onDisconnect={() => void handleDisconnect()}
|
||||
/>
|
||||
)}
|
||||
|
||||
{modalOpen && selectedProvider && (
|
||||
<VoiceProviderSetupModal
|
||||
providerType={selectedProvider}
|
||||
|
||||
278
web/src/sections/modals/llmConfig/BifrostModal.tsx
Normal file
278
web/src/sections/modals/llmConfig/BifrostModal.tsx
Normal file
@@ -0,0 +1,278 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchBifrostModels } from "@/app/admin/configuration/llm/utils";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const BIFROST_PROVIDER_NAME = LLMProviderName.BIFROST;
|
||||
const DEFAULT_API_BASE = "";
|
||||
|
||||
interface BifrostModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
|
||||
interface BifrostModalInternalsProps {
|
||||
formikProps: FormikProps<BifrostModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function BifrostModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: BifrostModalInternalsProps) {
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
|
||||
const isFetchDisabled = !formikProps.values.api_base;
|
||||
|
||||
const handleFetchModels = async () => {
|
||||
const { models, error } = await fetchBifrostModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
setFetchedModels(models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
useEffect(() => {
|
||||
if (existingLlmProvider && !isFetchDisabled) {
|
||||
handleFetchModels().catch((err) => {
|
||||
console.error("Failed to fetch Bifrost models:", err);
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to fetch models"
|
||||
);
|
||||
});
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={LLMProviderName.BIFROST}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="https://your-bifrost-gateway.com/v1"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_key"
|
||||
title="API Key"
|
||||
optional={true}
|
||||
subDescription={markdown(
|
||||
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
|
||||
)}
|
||||
>
|
||||
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. anthropic/claude-sonnet-4-6" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
export default function BifrostModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
BIFROST_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: BifrostModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: BIFROST_PROVIDER_NAME,
|
||||
provider: BIFROST_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
} as BifrostModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
|
||||
return (
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: BIFROST_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: BIFROST_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<BifrostModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
|
||||
function detectIfRealOpenAIProvider(provider: LLMProviderView) {
|
||||
return (
|
||||
@@ -56,6 +57,8 @@ export function getModalForExistingProvider(
|
||||
return <LMStudioForm {...props} />;
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
return <LiteLLMProxyModal {...props} />;
|
||||
case LLMProviderName.BIFROST:
|
||||
return <BifrostModal {...props} />;
|
||||
default:
|
||||
return <CustomModal {...props} />;
|
||||
}
|
||||
|
||||
246
web/tests/e2e/admin/image-generation/disconnect-provider.spec.ts
Normal file
246
web/tests/e2e/admin/image-generation/disconnect-provider.spec.ts
Normal file
@@ -0,0 +1,246 @@
|
||||
import { test, expect, Page, Locator } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
|
||||
const IMAGE_GENERATION_URL = "/admin/configuration/image-generation";
|
||||
|
||||
const FAKE_CONNECTED_CONFIG = {
|
||||
image_provider_id: "openai_dalle_3",
|
||||
model_configuration_id: 100,
|
||||
model_name: "dall-e-3",
|
||||
llm_provider_id: 100,
|
||||
llm_provider_name: "openai-dalle3",
|
||||
is_default: false,
|
||||
};
|
||||
|
||||
const FAKE_DEFAULT_CONFIG = {
|
||||
image_provider_id: "openai_gpt_image_1",
|
||||
model_configuration_id: 101,
|
||||
model_name: "gpt-image-1",
|
||||
llm_provider_id: 101,
|
||||
llm_provider_name: "openai-gpt-image-1",
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
function getProviderCard(page: Page, providerId: string): Locator {
|
||||
return page.getByLabel(`image-gen-provider-${providerId}`, { exact: true });
|
||||
}
|
||||
|
||||
function mainContainer(page: Page): Locator {
|
||||
return page.locator("[data-main-container]");
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets up route mocks so the page sees configured providers
|
||||
* without needing real API keys.
|
||||
*/
|
||||
async function mockImageGenApis(
|
||||
page: Page,
|
||||
configs: (typeof FAKE_CONNECTED_CONFIG)[]
|
||||
) {
|
||||
await page.route("**/api/admin/image-generation/config", async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({ status: 200, json: configs });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
});
|
||||
|
||||
await page.route(
|
||||
"**/api/admin/llm/provider?include_image_gen=true",
|
||||
async (route) => {
|
||||
await route.fulfill({ status: 200, json: { providers: [] } });
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
test.describe("Image Generation Provider Disconnect", () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin");
|
||||
});
|
||||
|
||||
test("should disconnect a connected (non-default) provider", async ({
|
||||
page,
|
||||
}) => {
|
||||
const configs = [{ ...FAKE_CONNECTED_CONFIG }, { ...FAKE_DEFAULT_CONFIG }];
|
||||
await mockImageGenApis(page, configs);
|
||||
|
||||
await page.goto(IMAGE_GENERATION_URL);
|
||||
await page.waitForSelector("text=Image Generation Model", {
|
||||
timeout: 20000,
|
||||
});
|
||||
|
||||
const card = getProviderCard(page, "openai_dalle_3");
|
||||
await card.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "image-gen-disconnect-non-default-before",
|
||||
});
|
||||
|
||||
// Verify disconnect button exists and is enabled
|
||||
const disconnectButton = card.getByRole("button", {
|
||||
name: "Disconnect DALL-E 3",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
// Mock the DELETE to succeed and update the config list
|
||||
await page.route(
|
||||
"**/api/admin/image-generation/config/openai_dalle_3",
|
||||
async (route) => {
|
||||
if (route.request().method() === "DELETE") {
|
||||
// Update the GET mock to return only the default config
|
||||
await page.unroute("**/api/admin/image-generation/config");
|
||||
await page.route(
|
||||
"**/api/admin/image-generation/config",
|
||||
async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
json: [{ ...FAKE_DEFAULT_CONFIG }],
|
||||
});
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
await route.fulfill({ status: 200, json: {} });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// Click disconnect
|
||||
await disconnectButton.click();
|
||||
|
||||
// Verify confirmation modal appears
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect DALL-E 3");
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "image-gen-disconnect-non-default-modal",
|
||||
});
|
||||
|
||||
// Click Disconnect in the confirmation modal
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await confirmButton.click();
|
||||
|
||||
// Verify the card reverts to disconnected state (shows "Connect" button)
|
||||
await expect(card.getByRole("button", { name: "Connect" })).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "image-gen-disconnect-non-default-after",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show replacement dropdown when disconnecting default provider with alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
const configs = [{ ...FAKE_CONNECTED_CONFIG }, { ...FAKE_DEFAULT_CONFIG }];
|
||||
await mockImageGenApis(page, configs);
|
||||
|
||||
await page.goto(IMAGE_GENERATION_URL);
|
||||
await page.waitForSelector("text=Image Generation Model", {
|
||||
timeout: 20000,
|
||||
});
|
||||
|
||||
const defaultCard = getProviderCard(page, "openai_gpt_image_1");
|
||||
await defaultCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
// The disconnect button should be visible and enabled
|
||||
const disconnectButton = defaultCard.getByRole("button", {
|
||||
name: "Disconnect GPT Image 1",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Should show replacement dropdown since there's an alternative
|
||||
await expect(
|
||||
confirmDialog.getByText("Session history will be preserved")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled because first replacement is auto-selected
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "image-gen-disconnect-default-with-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show connect message when disconnecting default provider with no alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Only the default config — no other providers configured
|
||||
await mockImageGenApis(page, [{ ...FAKE_DEFAULT_CONFIG }]);
|
||||
|
||||
await page.goto(IMAGE_GENERATION_URL);
|
||||
await page.waitForSelector("text=Image Generation Model", {
|
||||
timeout: 20000,
|
||||
});
|
||||
|
||||
const defaultCard = getProviderCard(page, "openai_gpt_image_1");
|
||||
await defaultCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = defaultCard.getByRole("button", {
|
||||
name: "Disconnect GPT Image 1",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Should show message about connecting another provider
|
||||
await expect(
|
||||
confirmDialog.getByText("Connect another provider")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "image-gen-disconnect-no-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should not show disconnect button for unconfigured providers", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockImageGenApis(page, [{ ...FAKE_DEFAULT_CONFIG }]);
|
||||
|
||||
await page.goto(IMAGE_GENERATION_URL);
|
||||
await page.waitForSelector("text=Image Generation Model", {
|
||||
timeout: 20000,
|
||||
});
|
||||
|
||||
// DALL-E 3 is not configured — should not have a disconnect button
|
||||
const card = getProviderCard(page, "openai_dalle_3");
|
||||
await card.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = card.getByRole("button", {
|
||||
name: "Disconnect DALL-E 3",
|
||||
});
|
||||
await expect(disconnectButton).not.toBeVisible();
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "image-gen-disconnect-unconfigured",
|
||||
});
|
||||
});
|
||||
});
|
||||
317
web/tests/e2e/admin/voice/disconnect-provider.spec.ts
Normal file
317
web/tests/e2e/admin/voice/disconnect-provider.spec.ts
Normal file
@@ -0,0 +1,317 @@
|
||||
import { test, expect, Page, Locator } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
|
||||
const VOICE_URL = "/admin/configuration/voice";
|
||||
|
||||
const FAKE_PROVIDERS = {
|
||||
openai_active_stt: {
|
||||
id: 1,
|
||||
name: "openai",
|
||||
provider_type: "openai",
|
||||
is_default_stt: true,
|
||||
is_default_tts: false,
|
||||
stt_model: "whisper",
|
||||
tts_model: null,
|
||||
default_voice: null,
|
||||
has_api_key: true,
|
||||
target_uri: null,
|
||||
},
|
||||
openai_active_both: {
|
||||
id: 1,
|
||||
name: "openai",
|
||||
provider_type: "openai",
|
||||
is_default_stt: true,
|
||||
is_default_tts: true,
|
||||
stt_model: "whisper",
|
||||
tts_model: "tts-1",
|
||||
default_voice: "alloy",
|
||||
has_api_key: true,
|
||||
target_uri: null,
|
||||
},
|
||||
openai_connected: {
|
||||
id: 1,
|
||||
name: "openai",
|
||||
provider_type: "openai",
|
||||
is_default_stt: false,
|
||||
is_default_tts: false,
|
||||
stt_model: null,
|
||||
tts_model: null,
|
||||
default_voice: null,
|
||||
has_api_key: true,
|
||||
target_uri: null,
|
||||
},
|
||||
elevenlabs_connected: {
|
||||
id: 2,
|
||||
name: "elevenlabs",
|
||||
provider_type: "elevenlabs",
|
||||
is_default_stt: false,
|
||||
is_default_tts: false,
|
||||
stt_model: null,
|
||||
tts_model: null,
|
||||
default_voice: null,
|
||||
has_api_key: true,
|
||||
target_uri: null,
|
||||
},
|
||||
};
|
||||
|
||||
function findModelCard(page: Page, ariaLabel: string): Locator {
|
||||
return page.getByLabel(ariaLabel, { exact: true });
|
||||
}
|
||||
|
||||
function mainContainer(page: Page): Locator {
|
||||
return page.locator("[data-main-container]");
|
||||
}
|
||||
|
||||
async function mockVoiceApis(
|
||||
page: Page,
|
||||
providers: (typeof FAKE_PROVIDERS)[keyof typeof FAKE_PROVIDERS][]
|
||||
) {
|
||||
await page.route("**/api/admin/voice/providers", async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({ status: 200, json: providers });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
test.describe("Voice Provider Disconnect", () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin");
|
||||
});
|
||||
|
||||
test("should disconnect a non-active provider and affect both STT and TTS cards", async ({
|
||||
page,
|
||||
}) => {
|
||||
const providers = [
|
||||
{ ...FAKE_PROVIDERS.openai_connected },
|
||||
{ ...FAKE_PROVIDERS.elevenlabs_connected },
|
||||
];
|
||||
await mockVoiceApis(page, providers);
|
||||
|
||||
await page.goto(VOICE_URL);
|
||||
await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
|
||||
|
||||
const whisperCard = findModelCard(page, "voice-stt-whisper");
|
||||
await whisperCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-non-active-before",
|
||||
});
|
||||
|
||||
const disconnectButton = whisperCard.getByRole("button", {
|
||||
name: "Disconnect Whisper",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
// Mock DELETE to succeed and remove OpenAI from provider list
|
||||
await page.route("**/api/admin/voice/providers/1", async (route) => {
|
||||
if (route.request().method() === "DELETE") {
|
||||
await page.unroute("**/api/admin/voice/providers");
|
||||
await page.route("**/api/admin/voice/providers", async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
json: [{ ...FAKE_PROVIDERS.elevenlabs_connected }],
|
||||
});
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
});
|
||||
await route.fulfill({ status: 200, json: {} });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
});
|
||||
|
||||
await disconnectButton.click();
|
||||
|
||||
// Modal shows provider name, not model name
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect OpenAI");
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "voice-disconnect-non-active-modal",
|
||||
});
|
||||
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await confirmButton.click();
|
||||
|
||||
// Both STT and TTS cards for OpenAI revert to disconnected
|
||||
await expect(
|
||||
whisperCard.getByRole("button", { name: "Connect" })
|
||||
).toBeVisible({ timeout: 10000 });
|
||||
|
||||
const tts1Card = findModelCard(page, "voice-tts-tts-1");
|
||||
await expect(tts1Card.getByRole("button", { name: "Connect" })).toBeVisible(
|
||||
{ timeout: 10000 }
|
||||
);
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-non-active-after",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show replacement dropdown when disconnecting active provider with alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// OpenAI is active for STT, ElevenLabs is also configured
|
||||
const providers = [
|
||||
{ ...FAKE_PROVIDERS.openai_active_stt },
|
||||
{ ...FAKE_PROVIDERS.elevenlabs_connected },
|
||||
];
|
||||
await mockVoiceApis(page, providers);
|
||||
|
||||
await page.goto(VOICE_URL);
|
||||
await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
|
||||
|
||||
const whisperCard = findModelCard(page, "voice-stt-whisper");
|
||||
await whisperCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-active-with-alt-before",
|
||||
});
|
||||
|
||||
const disconnectButton = whisperCard.getByRole("button", {
|
||||
name: "Disconnect Whisper",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect OpenAI");
|
||||
|
||||
// Should show replacement text and dropdown
|
||||
await expect(
|
||||
confirmDialog.getByText("Session history will be preserved")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled because first replacement is auto-selected
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "voice-disconnect-active-with-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show replacement when provider is default for both STT and TTS", async ({
|
||||
page,
|
||||
}) => {
|
||||
// OpenAI is default for both modes, ElevenLabs also configured
|
||||
const providers = [
|
||||
{ ...FAKE_PROVIDERS.openai_active_both },
|
||||
{ ...FAKE_PROVIDERS.elevenlabs_connected },
|
||||
];
|
||||
await mockVoiceApis(page, providers);
|
||||
|
||||
await page.goto(VOICE_URL);
|
||||
await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
|
||||
|
||||
const whisperCard = findModelCard(page, "voice-stt-whisper");
|
||||
await whisperCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-both-modes-before",
|
||||
});
|
||||
|
||||
const disconnectButton = whisperCard.getByRole("button", {
|
||||
name: "Disconnect Whisper",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect OpenAI");
|
||||
|
||||
// Should mention both modes
|
||||
await expect(
|
||||
confirmDialog.getByText("speech-to-text or text-to-speech")
|
||||
).toBeVisible();
|
||||
|
||||
// Should show replacement dropdown
|
||||
await expect(
|
||||
confirmDialog.getByText("Session history will be preserved")
|
||||
).toBeVisible();
|
||||
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "voice-disconnect-both-modes-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show connect message when disconnecting active provider with no alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Only OpenAI configured, active for STT — no other providers
|
||||
const providers = [{ ...FAKE_PROVIDERS.openai_active_stt }];
|
||||
await mockVoiceApis(page, providers);
|
||||
|
||||
await page.goto(VOICE_URL);
|
||||
await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
|
||||
|
||||
const whisperCard = findModelCard(page, "voice-stt-whisper");
|
||||
await whisperCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-no-alt-before",
|
||||
});
|
||||
|
||||
const disconnectButton = whisperCard.getByRole("button", {
|
||||
name: "Disconnect Whisper",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect OpenAI");
|
||||
|
||||
// Should show message about connecting another provider
|
||||
await expect(
|
||||
confirmDialog.getByText("Connect another provider")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "voice-disconnect-no-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should not show disconnect button for unconfigured provider", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockVoiceApis(page, []);
|
||||
|
||||
await page.goto(VOICE_URL);
|
||||
await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
|
||||
|
||||
const whisperCard = findModelCard(page, "voice-stt-whisper");
|
||||
await whisperCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = whisperCard.getByRole("button", {
|
||||
name: "Disconnect Whisper",
|
||||
});
|
||||
await expect(disconnectButton).not.toBeVisible();
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-unconfigured",
|
||||
});
|
||||
});
|
||||
});
|
||||
394
web/tests/e2e/admin/web-search/disconnect-provider.spec.ts
Normal file
394
web/tests/e2e/admin/web-search/disconnect-provider.spec.ts
Normal file
@@ -0,0 +1,394 @@
|
||||
import { test, expect, Page, Locator } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
|
||||
const WEB_SEARCH_URL = "/admin/configuration/web-search";
|
||||
|
||||
const FAKE_SEARCH_PROVIDERS = {
|
||||
exa: {
|
||||
id: 1,
|
||||
name: "Exa",
|
||||
provider_type: "exa",
|
||||
is_active: true,
|
||||
config: null,
|
||||
has_api_key: true,
|
||||
},
|
||||
brave: {
|
||||
id: 2,
|
||||
name: "Brave",
|
||||
provider_type: "brave",
|
||||
is_active: false,
|
||||
config: null,
|
||||
has_api_key: true,
|
||||
},
|
||||
};
|
||||
|
||||
const FAKE_CONTENT_PROVIDERS = {
|
||||
firecrawl: {
|
||||
id: 10,
|
||||
name: "Firecrawl",
|
||||
provider_type: "firecrawl",
|
||||
is_active: true,
|
||||
config: { base_url: "https://api.firecrawl.dev/v2/scrape" },
|
||||
has_api_key: true,
|
||||
},
|
||||
exa: {
|
||||
id: 11,
|
||||
name: "Exa",
|
||||
provider_type: "exa",
|
||||
is_active: false,
|
||||
config: null,
|
||||
has_api_key: true,
|
||||
},
|
||||
};
|
||||
|
||||
function findProviderCard(page: Page, providerLabel: string): Locator {
|
||||
return page
|
||||
.locator("div.rounded-16")
|
||||
.filter({ hasText: providerLabel })
|
||||
.first();
|
||||
}
|
||||
|
||||
function mainContainer(page: Page): Locator {
|
||||
return page.locator("[data-main-container]");
|
||||
}
|
||||
|
||||
async function mockWebSearchApis(
|
||||
page: Page,
|
||||
searchProviders: (typeof FAKE_SEARCH_PROVIDERS)[keyof typeof FAKE_SEARCH_PROVIDERS][],
|
||||
contentProviders: (typeof FAKE_CONTENT_PROVIDERS)[keyof typeof FAKE_CONTENT_PROVIDERS][]
|
||||
) {
|
||||
await page.route(
|
||||
"**/api/admin/web-search/search-providers",
|
||||
async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({ status: 200, json: searchProviders });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
await page.route(
|
||||
"**/api/admin/web-search/content-providers",
|
||||
async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({ status: 200, json: contentProviders });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
test.describe("Web Search Provider Disconnect", () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin");
|
||||
});
|
||||
|
||||
test.describe("Search Engine Providers", () => {
|
||||
test("should disconnect a connected (non-active) search provider", async ({
|
||||
page,
|
||||
}) => {
|
||||
const searchProviders = [
|
||||
{ ...FAKE_SEARCH_PROVIDERS.exa },
|
||||
{ ...FAKE_SEARCH_PROVIDERS.brave },
|
||||
];
|
||||
await mockWebSearchApis(page, searchProviders, []);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Search Engine", { timeout: 20000 });
|
||||
|
||||
const braveCard = findProviderCard(page, "Brave");
|
||||
await braveCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "web-search-disconnect-non-active-before",
|
||||
});
|
||||
|
||||
const disconnectButton = braveCard.getByRole("button", {
|
||||
name: "Disconnect Brave",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
// Mock the DELETE to succeed
|
||||
await page.route(
|
||||
"**/api/admin/web-search/search-providers/2",
|
||||
async (route) => {
|
||||
if (route.request().method() === "DELETE") {
|
||||
await page.unroute("**/api/admin/web-search/search-providers");
|
||||
await page.route(
|
||||
"**/api/admin/web-search/search-providers",
|
||||
async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
json: [{ ...FAKE_SEARCH_PROVIDERS.exa }],
|
||||
});
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
await route.fulfill({ status: 200, json: {} });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect Brave");
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "web-search-disconnect-non-active-modal",
|
||||
});
|
||||
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await confirmButton.click();
|
||||
|
||||
await expect(
|
||||
braveCard.getByRole("button", { name: "Connect" })
|
||||
).toBeVisible({ timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "web-search-disconnect-non-active-after",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show replacement dropdown when disconnecting active search provider with alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Exa is active, Brave is also configured
|
||||
const searchProviders = [
|
||||
{ ...FAKE_SEARCH_PROVIDERS.exa },
|
||||
{ ...FAKE_SEARCH_PROVIDERS.brave },
|
||||
];
|
||||
await mockWebSearchApis(page, searchProviders, []);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Search Engine", { timeout: 20000 });
|
||||
|
||||
const exaCard = findProviderCard(page, "Exa");
|
||||
await exaCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = exaCard.getByRole("button", {
|
||||
name: "Disconnect Exa",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect Exa");
|
||||
|
||||
// Should show replacement dropdown
|
||||
await expect(
|
||||
confirmDialog.getByText("Search history will be preserved")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled because first replacement is auto-selected
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "web-search-disconnect-active-with-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show connect message when disconnecting active search provider with no alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Only Exa configured and active
|
||||
await mockWebSearchApis(page, [{ ...FAKE_SEARCH_PROVIDERS.exa }], []);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Search Engine", { timeout: 20000 });
|
||||
|
||||
const exaCard = findProviderCard(page, "Exa");
|
||||
await exaCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = exaCard.getByRole("button", {
|
||||
name: "Disconnect Exa",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Should show message about connecting another provider
|
||||
await expect(
|
||||
confirmDialog.getByText("Connect another provider")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "web-search-disconnect-no-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should not show disconnect button for unconfigured search provider", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockWebSearchApis(page, [{ ...FAKE_SEARCH_PROVIDERS.exa }], []);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Search Engine", { timeout: 20000 });
|
||||
|
||||
const braveCard = findProviderCard(page, "Brave");
|
||||
await braveCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = braveCard.getByRole("button", {
|
||||
name: "Disconnect Brave",
|
||||
});
|
||||
await expect(disconnectButton).not.toBeVisible();
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "web-search-disconnect-unconfigured",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test.describe("Web Crawler (Content) Providers", () => {
|
||||
test("should disconnect a connected (non-active) content provider", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Firecrawl connected but not active, Exa is active
|
||||
const contentProviders = [
|
||||
{ ...FAKE_CONTENT_PROVIDERS.firecrawl, is_active: false },
|
||||
{ ...FAKE_CONTENT_PROVIDERS.exa, is_active: true },
|
||||
];
|
||||
await mockWebSearchApis(page, [], contentProviders);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Web Crawler", { timeout: 20000 });
|
||||
|
||||
const firecrawlCard = findProviderCard(page, "Firecrawl");
|
||||
await firecrawlCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = firecrawlCard.getByRole("button", {
|
||||
name: "Disconnect Firecrawl",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
// Mock the DELETE to succeed
|
||||
await page.route(
|
||||
"**/api/admin/web-search/content-providers/10",
|
||||
async (route) => {
|
||||
if (route.request().method() === "DELETE") {
|
||||
await page.unroute("**/api/admin/web-search/content-providers");
|
||||
await page.route(
|
||||
"**/api/admin/web-search/content-providers",
|
||||
async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
json: [{ ...FAKE_CONTENT_PROVIDERS.exa, is_active: true }],
|
||||
});
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
await route.fulfill({ status: 200, json: {} });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect Firecrawl");
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "web-search-disconnect-content-non-active-modal",
|
||||
});
|
||||
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await confirmButton.click();
|
||||
|
||||
await expect(
|
||||
firecrawlCard.getByRole("button", { name: "Connect" })
|
||||
).toBeVisible({ timeout: 10000 });
|
||||
});
|
||||
|
||||
test("should show replacement dropdown when disconnecting active content provider with alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Firecrawl is active, Exa is also configured
|
||||
const contentProviders = [
|
||||
{ ...FAKE_CONTENT_PROVIDERS.firecrawl },
|
||||
{ ...FAKE_CONTENT_PROVIDERS.exa },
|
||||
];
|
||||
await mockWebSearchApis(page, [], contentProviders);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Web Crawler", { timeout: 20000 });
|
||||
|
||||
const firecrawlCard = findProviderCard(page, "Firecrawl");
|
||||
await firecrawlCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = firecrawlCard.getByRole("button", {
|
||||
name: "Disconnect Firecrawl",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Should show replacement dropdown
|
||||
await expect(
|
||||
confirmDialog.getByText("Search history will be preserved")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect should be enabled because first replacement is auto-selected
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "web-search-disconnect-content-active-with-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should not show disconnect for Onyx Web Crawler (built-in)", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockWebSearchApis(page, [], []);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Web Crawler", { timeout: 20000 });
|
||||
|
||||
const onyxCard = findProviderCard(page, "Onyx Web Crawler");
|
||||
await onyxCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = onyxCard.getByRole("button", {
|
||||
name: "Disconnect Onyx Web Crawler",
|
||||
});
|
||||
await expect(disconnectButton).not.toBeVisible();
|
||||
});
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user