Compare commits

..

17 Commits

Author SHA1 Message Date
Bo-Onyx
c9f59aad42 feat(hook): admin page create or edit hook 2026-03-26 18:49:02 -07:00
Jamison Lahman
b9e84c42a8 feat(providers): allow deleting all types of providers (#9625)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-03-26 15:20:56 -07:00
Bo-Onyx
0a1df52c2f feat(hook): Hook Form Modal Polish. (#9683) 2026-03-26 22:12:33 +00:00
Nikolas Garza
306b0d452f fix(billing): retry claimLicense up to 3x after Stripe checkout return (#9669) 2026-03-26 21:06:19 +00:00
Justin Tahara
5fdb34ba8e feat(llm): add Bifrost gateway frontend modal and provider registration (#9617) 2026-03-26 20:50:25 +00:00
Jamison Lahman
2d066631e3 fix(voice): dont soft-delete providers (#9679) 2026-03-26 19:26:32 +00:00
Evan Lohn
5c84f6c61b fix(jira): large batches fail json decode (#9677) 2026-03-26 18:53:37 +00:00
Nikolas Garza
899179d4b6 fix(api-key): clarify upgrade message for trial accounts (#9678) 2026-03-26 18:32:41 +00:00
Bo-Onyx
80d6bafc74 feat(hook): Hook connect/manage modal (#9645) 2026-03-26 18:16:33 +00:00
Nikolas Garza
2cc325cb0e chore(greptile): split greptile.json into .greptile/ directory (#9668) 2026-03-26 17:05:43 +00:00
Raunak Bhagat
849385b756 refactor: migrate legacy components/Text (#9628) 2026-03-26 16:14:03 +00:00
Ben Wu
417b9c12e4 feat(canvas): add API client, data models, and connector scaffold 1/6 (#9385)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-26 15:26:52 +00:00
Raunak Bhagat
30b37d0a77 fix(admin): wrap system prompt modal in Formik with markdown subDescription (#9667) 2026-03-26 07:08:56 -07:00
Justin Tahara
b48be0cd3a feat(llm): add Bifrost gateway as LLM provider (backend) (#9616) 2026-03-26 05:09:20 +00:00
Nikolas Garza
127fd90424 fix(metrics): replace inspect.ping() with event-based worker health monitoring (#9633) 2026-03-26 03:36:07 +00:00
Raunak Bhagat
f9c9e55f32 refactor(opal): accept string | RichStr in all Opal text-rendering components, modals, and input-layouts (#9656) 2026-03-26 02:46:34 +00:00
Raunak Bhagat
5afcf1acea fix(opal): remove gap between title and description in ContentMd (#9666) 2026-03-25 19:45:21 -07:00
118 changed files with 6962 additions and 3338 deletions

64
.greptile/config.json Normal file
View File

@@ -0,0 +1,64 @@
{
"labels": [],
"comment": "",
"fixWithAI": true,
"hideFooter": false,
"strictness": 3,
"statusCheck": true,
"commentTypes": [
"logic",
"syntax",
"style"
],
"instructions": "",
"disabledLabels": [],
"excludeAuthors": [
"dependabot[bot]",
"renovate[bot]"
],
"ignoreKeywords": "",
"ignorePatterns": "",
"includeAuthors": [],
"summarySection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"excludeBranches": [],
"fileChangeLimit": 300,
"includeBranches": [],
"includeKeywords": "",
"triggerOnUpdates": true,
"updateExistingSummaryComment": true,
"updateSummaryOnly": false,
"issuesTableSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"statusCommentsEnabled": true,
"confidenceScoreSection": {
"included": true,
"collapsible": false
},
"sequenceDiagramSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"shouldUpdateDescription": false,
"rules": [
{
"scope": ["web/**"],
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
},
{
"scope": ["web/**"],
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
}
]
}

57
.greptile/files.json Normal file
View File

@@ -0,0 +1,57 @@
[
{
"scope": [],
"path": "contributing_guides/best_practices.md",
"description": "Best practices for contributing to the codebase"
},
{
"scope": ["web/**"],
"path": "web/AGENTS.md",
"description": "Frontend coding standards for the web directory"
},
{
"scope": ["web/**"],
"path": "web/tests/README.md",
"description": "Frontend testing guide and conventions"
},
{
"scope": ["web/**"],
"path": "web/CLAUDE.md",
"description": "Single source of truth for frontend coding standards"
},
{
"scope": ["web/**"],
"path": "web/lib/opal/README.md",
"description": "Opal component library usage guide"
},
{
"scope": ["backend/**"],
"path": "backend/tests/README.md",
"description": "Backend testing guide covering all 4 test types, fixtures, and conventions"
},
{
"scope": ["backend/onyx/connectors/**"],
"path": "backend/onyx/connectors/README.md",
"description": "Connector development guide covering design, interfaces, and required changes"
},
{
"scope": [],
"path": "CLAUDE.md",
"description": "Project instructions and coding standards"
},
{
"scope": [],
"path": "backend/alembic/README.md",
"description": "Migration guidance, including multi-tenant migration behavior"
},
{
"scope": [],
"path": "deployment/helm/charts/onyx/values-lite.yaml",
"description": "Lite deployment Helm values and service assumptions"
},
{
"scope": [],
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
}
]

29
.greptile/rules.md Normal file
View File

@@ -0,0 +1,29 @@
# Greptile Review Rules
## Type Annotations
Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code.
## Best Practices
Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly.
## TODOs
Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of `TODO(name): ...` or `TODO(1234): ...`
## Debugging Code
Remove temporary debugging code before merging to production, especially tenant-specific debugging logs.
## Hardcoded Booleans
When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant.
## Multi-tenant vs Single-tenant
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
## Full vs Lite Deployments
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.

View File

@@ -0,0 +1,35 @@
"""remove voice_provider deleted column
Revision ID: 1d78c0ca7853
Revises: a3f8b2c1d4e5
Create Date: 2026-03-26 11:30:53.883127
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "1d78c0ca7853"
down_revision = "a3f8b2c1d4e5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Hard-delete any soft-deleted rows before dropping the column
op.execute("DELETE FROM voice_provider WHERE deleted = true")
op.drop_column("voice_provider", "deleted")
def downgrade() -> None:
op.add_column(
"voice_provider",
sa.Column(
"deleted",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)

View File

@@ -1,8 +1,19 @@
import threading
import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.tools.models import ToolCallInfo
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background
# Type alias for search doc deduplication key
# Simple key: just document_id (str)
@@ -148,3 +159,114 @@ class ChatStateContainer:
"""Thread-safe getter for emitted citations (returns a copy)."""
with self._lock:
return self._emitted_citations.copy()
def run_chat_loop_with_state_containers(
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
) -> Generator[Packet, None]:
"""
Explicit wrapper function that runs a function in a background thread
with event streaming capabilities.
The wrapped function should accept emitter as first arg and use it to emit
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
Args:
func: The function to wrap (should accept emitter and state_container as first and second args)
completion_callback: Callback function to call when the function completes
emitter: Emitter instance for sending packets
state_container: ChatStateContainer instance for accumulating state
is_connected: Callable that returns False when stop signal is set
Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
)
for packet in packets:
# Process packets
pass
"""
def run_with_exception_capture() -> None:
try:
chat_loop_func(emitter, state_container)
except Exception as e:
# If execution fails, emit an exception packet
emitter.emit(
Packet(
placement=Placement(turn_index=0),
obj=PacketException(type="error", exception=e),
)
)
# Run the function in a background thread
thread = run_in_background(run_with_exception_capture)
pkt: Packet | None = None
last_turn_index = 0 # Track the highest turn_index seen for stop packet
last_cancel_check = time.monotonic()
cancel_check_interval = 0.3 # Check for cancellation every 300ms
try:
while True:
# Poll queue with 300ms timeout for natural stop signal checking
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
try:
pkt = emitter.bus.get(timeout=0.3)
except Empty:
if not is_connected():
# Stop signal detected
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = time.monotonic()
continue
if pkt is not None:
# Track the highest turn_index for the stop packet
if pkt.placement and pkt.placement.turn_index > last_turn_index:
last_turn_index = pkt.placement.turn_index
if isinstance(pkt.obj, OverallStop):
yield pkt
break
elif isinstance(pkt.obj, PacketException):
raise pkt.obj.exception
else:
yield pkt
# Check for cancellation periodically even when packets are flowing
# This ensures stop signal is checked during active streaming
current_time = time.monotonic()
if current_time - last_cancel_check >= cancel_check_interval:
if not is_connected():
# Stop signal detected during streaming
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = current_time
finally:
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
# Skip waiting if user disconnected to exit quickly.
if is_connected():
wait_on_background(thread)
try:
completion_callback(state_container)
except Exception as e:
emitter.emit(
Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=PacketException(type="error", exception=e),
)
)

View File

@@ -1,80 +1,19 @@
import queue
from queue import Queue
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
class Emitter:
"""Routes packets produced during tool and LLM execution to the right destination.
"""Use this inside tools to emit arbitrary UI progress."""
Operates in one of two modes determined by whether ``merged_queue`` is supplied:
**Standalone** (no ``merged_queue``): packets land on ``self.bus``. Used by tests,
custom tools, and any caller that reads the emitter directly after execution.
**Streaming** (``merged_queue`` provided): packets are tagged with ``model_index``
and placed as ``(key, packet)`` tuples on the shared queue for the
``_run_models`` drain loop to consume and yield downstream.
Attributes:
bus: Fallback queue for standalone mode. Always created so existing callers
(tests, eval harnesses, custom-tool scripts) work without modification.
Args:
model_idx: Index embedded in packet placements. Pass ``None`` for single-model
runs to preserve the backwards-compatible wire format (``model_index=None``
in the packet); pass an integer for each model in a multi-model run.
merged_queue: Shared queue owned by the ``_run_models`` drain loop. When set,
all ``emit()`` calls route here instead of ``self.bus``.
Example::
# Standalone — read from bus after the fact (tests, evals)
emitter = Emitter()
emitter.emit(packet)
result = emitter.bus.get()
# Streaming — wired into _run_models (production path)
emitter = Emitter(model_idx=0, merged_queue=merged_queue)
emitter.emit(packet) # places (0, tagged_packet) on merged_queue
"""
def __init__(
self,
model_idx: int | None = None,
merged_queue: "queue.Queue | None" = None,
) -> None:
self._model_idx = model_idx
self._merged_queue = merged_queue
# Always created for backwards compatibility (tests, custom_tool, customer scripts, etc.)
self.bus: Queue[Packet] = Queue()
def __init__(self, bus: Queue):
self.bus = bus
def emit(self, packet: Packet) -> None:
"""Emit a packet, routing it to the merged queue or the local bus.
In streaming mode, stamps the packet's placement with ``model_index`` before
forwarding so the drain loop can attribute it to the correct model. In
standalone mode, places the packet on ``self.bus`` unchanged.
Args:
packet: The packet to emit.
"""
if self._merged_queue is not None:
tagged_placement = Placement(
turn_index=packet.placement.turn_index if packet.placement else 0,
tab_index=packet.placement.tab_index if packet.placement else 0,
sub_turn_index=(
packet.placement.sub_turn_index if packet.placement else None
),
model_index=self._model_idx,
)
tagged_packet = Packet(placement=tagged_placement, obj=packet.obj)
key = self._model_idx if self._model_idx is not None else 0
self._merged_queue.put((key, tagged_packet))
else:
self.bus.put(packet)
self.bus.put(packet) # Thread-safe
def get_default_emitter() -> Emitter:
return Emitter()
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
return emitter

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,192 @@
from __future__ import annotations
import logging
import re
from typing import Any
from urllib.parse import urlparse
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rl_requests,
)
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
logger = logging.getLogger(__name__)
# Requests timeout in seconds.
_CANVAS_CALL_TIMEOUT: int = 30
_CANVAS_API_VERSION: str = "/api/v1"
# Matches the "next" URL in a Canvas Link header, e.g.:
# <https://canvas.example.com/api/v1/courses?page=2>; rel="next"
# Captures the URL inside the angle brackets.
_NEXT_LINK_PATTERN: re.Pattern[str] = re.compile(r'<([^>]+)>;\s*rel="next"')
_STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
404: OnyxErrorCode.BAD_GATEWAY,
429: OnyxErrorCode.RATE_LIMITED,
}
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
"""Map an HTTP status code to the appropriate OnyxErrorCode.
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
mapped to specific error codes; all other codes (unrecognised 4xx
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
"""
if status_code in _STATUS_TO_ERROR_CODE:
return _STATUS_TO_ERROR_CODE[status_code]
return OnyxErrorCode.BAD_GATEWAY
class CanvasApiClient:
def __init__(
self,
bearer_token: str,
canvas_base_url: str,
) -> None:
parsed_base = urlparse(canvas_base_url)
if not parsed_base.hostname:
raise ValueError("canvas_base_url must include a valid host")
if parsed_base.scheme != "https":
raise ValueError("canvas_base_url must use https")
self._bearer_token = bearer_token
self.base_url = (
canvas_base_url.rstrip("/").removesuffix(_CANVAS_API_VERSION)
+ _CANVAS_API_VERSION
)
# Hostname is already validated above; reuse parsed_base instead
# of re-parsing. Used by _parse_next_link to validate pagination URLs.
self._expected_host: str = parsed_base.hostname
def get(
self,
endpoint: str = "",
params: dict[str, Any] | None = None,
full_url: str | None = None,
) -> tuple[Any, str | None]:
"""Make a GET request to the Canvas API.
Returns a tuple of (json_body, next_url).
next_url is parsed from the Link header and is None if there are no more pages.
If full_url is provided, it is used directly (for following pagination links).
Security note: full_url must only be set to values returned by
``_parse_next_link``, which validates the host against the configured
Canvas base URL. Passing an arbitrary URL would leak the bearer token.
"""
# full_url is used when following pagination (Canvas returns the
# next-page URL in the Link header). For the first request we build
# the URL from the endpoint name instead.
url = full_url if full_url else self._build_url(endpoint)
headers = self._build_headers()
response = rl_requests.get(
url,
headers=headers,
params=params if not full_url else None,
timeout=_CANVAS_CALL_TIMEOUT,
)
try:
response_json = response.json()
except ValueError as e:
if response.status_code < 300:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail=f"Invalid JSON in Canvas response: {e}",
)
logger.warning(
"Failed to parse JSON from Canvas error response (status=%d): %s",
response.status_code,
e,
)
response_json = {}
if response.status_code >= 400:
# Try to extract the most specific error message from the
# Canvas response body. Canvas uses three different shapes
# depending on the endpoint and error type:
default_error: str = response.reason or f"HTTP {response.status_code}"
error = default_error
if isinstance(response_json, dict):
# Shape 1: {"error": {"message": "Not authorized"}}
error_field = response_json.get("error")
if isinstance(error_field, dict):
response_error = error_field.get("message", "")
if response_error:
error = response_error
# Shape 2: {"error": "Invalid access token"}
elif isinstance(error_field, str):
error = error_field
# Shape 3: {"errors": [{"message": "..."}]}
# Used for validation errors. Only use as fallback if
# we didn't already find a more specific message above.
if error == default_error:
errors_list = response_json.get("errors")
if isinstance(errors_list, list) and errors_list:
first_error = errors_list[0]
if isinstance(first_error, dict):
msg = first_error.get("message", "")
if msg:
error = msg
raise OnyxError(
_error_code_for_status(response.status_code),
detail=error,
status_code_override=response.status_code,
)
next_url = self._parse_next_link(response.headers.get("Link", ""))
return response_json, next_url
def _parse_next_link(self, link_header: str) -> str | None:
"""Extract the 'next' URL from a Canvas Link header.
Only returns URLs whose host matches the configured Canvas base URL
to prevent leaking the bearer token to arbitrary hosts.
"""
expected_host = self._expected_host
for match in _NEXT_LINK_PATTERN.finditer(link_header):
url = match.group(1)
parsed_url = urlparse(url)
if parsed_url.hostname != expected_host:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail=(
"Canvas pagination returned an unexpected host "
f"({parsed_url.hostname}); expected {expected_host}"
),
)
if parsed_url.scheme != "https":
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
detail=(
"Canvas pagination link must use https, "
f"got {parsed_url.scheme!r}"
),
)
return url
return None
def _build_headers(self) -> dict[str, str]:
"""Return the Authorization header with the bearer token."""
return {"Authorization": f"Bearer {self._bearer_token}"}
def _build_url(self, endpoint: str) -> str:
"""Build a full Canvas API URL from an endpoint path.
Assumes endpoint is non-empty (e.g. ``"courses"``, ``"announcements"``).
Only called on a first request, endpoint must be set for first request.
Verify endpoint exists in case of future changes where endpoint might be optional.
Leading slashes are stripped to avoid double-slash in the result.
self.base_url is already normalized with no trailing slash.
"""
final_url = self.base_url
clean_endpoint = endpoint.lstrip("/")
if clean_endpoint:
final_url += "/" + clean_endpoint
return final_url

View File

@@ -0,0 +1,74 @@
from typing import Literal
from typing import TypeAlias
from pydantic import BaseModel
from onyx.connectors.models import ConnectorCheckpoint
class CanvasCourse(BaseModel):
id: int
name: str
course_code: str
created_at: str
workflow_state: str
class CanvasPage(BaseModel):
page_id: int
url: str
title: str
body: str | None = None
created_at: str
updated_at: str
course_id: int
class CanvasAssignment(BaseModel):
id: int
name: str
description: str | None = None
html_url: str
course_id: int
created_at: str
updated_at: str
due_at: str | None = None
class CanvasAnnouncement(BaseModel):
id: int
title: str
message: str | None = None
html_url: str
posted_at: str | None = None
course_id: int
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
"""Checkpoint state for resumable Canvas indexing.
Fields:
course_ids: Materialized list of course IDs to process.
current_course_index: Index into course_ids for current course.
stage: Which item type we're processing for the current course.
next_url: Pagination cursor within the current stage. None means
start from the first page; a URL means resume from that page.
Invariant:
If current_course_index is incremented, stage must be reset to
"pages" and next_url must be reset to None.
"""
course_ids: list[int] = []
current_course_index: int = 0
stage: CanvasStage = "pages"
next_url: str | None = None
def advance_course(self) -> None:
"""Move to the next course and reset within-course state."""
self.current_course_index += 1
self.stage = "pages"
self.next_url = None

View File

@@ -10,6 +10,7 @@ from datetime import timedelta
from datetime import timezone
from typing import Any
import requests
from jira import JIRA
from jira.exceptions import JIRAError
from jira.resources import Issue
@@ -239,29 +240,53 @@ def enhanced_search_ids(
)
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO: move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
def _bulk_fetch_request(
jira_client: JIRA, issue_ids: list[str], fields: str | None
) -> list[dict[str, Any]]:
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
# Prepare the payload according to Jira API v3 specification
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
# Only restrict fields if specified, might want to explicitly do this in the future
# to avoid reading unnecessary data
payload["fields"] = fields.split(",") if fields else ["*all"]
resp = jira_client._session.post(bulk_fetch_path, json=payload)
return resp.json()["issues"]
def bulk_fetch_issues(
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
# TODO(evan): move away from this jira library if they continue to not support
# the endpoints we need. Using private fields is not ideal, but
# is likely fine for now since we pin the library version
try:
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
except requests.exceptions.JSONDecodeError:
if len(issue_ids) <= 1:
logger.exception(
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
f"be decoded as JSON (response too large or truncated)."
)
raise
mid = len(issue_ids) // 2
logger.warning(
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
)
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
return left + right
except Exception as e:
logger.error(f"Error fetching issues: {e}")
raise e
raise
return [
Issue(jira_client._options, jira_client._session, raw=issue)
for issue in response["issues"]
for issue in raw_issues
]

View File

@@ -617,92 +617,6 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
db_session: Session,
chat_session_id: UUID,
parent_message_id: int,
model_display_names: list[str],
) -> list[ChatMessage]:
"""Reserve N assistant message placeholders for multi-model parallel streaming.
All messages share the same parent (the user message). The parent's
latest_child_message_id points to the LAST reserved message so that the
default history-chain walker picks it up.
"""
reserved: list[ChatMessage] = []
for display_name in model_display_names:
msg = ChatMessage(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
latest_child_message_id=None,
message="Response was terminated prior to completion, try regenerating.",
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
message_type=MessageType.ASSISTANT,
model_display_name=display_name,
)
db_session.add(msg)
reserved.append(msg)
# Flush to assign IDs without committing yet
db_session.flush()
# Point parent's latest_child to the last reserved message
parent = (
db_session.query(ChatMessage)
.filter(ChatMessage.id == parent_message_id)
.first()
)
if parent:
parent.latest_child_message_id = reserved[-1].id
db_session.commit()
return reserved
def set_preferred_response(
db_session: Session,
user_message_id: int,
preferred_assistant_message_id: int,
) -> None:
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
Also advances ``latest_child_message_id`` so the preferred response becomes
the active branch for any subsequent messages in the conversation.
Args:
db_session: Active database session.
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
preferred response is being set.
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
Raises:
ValueError: If either message is not found, if ``user_message_id`` does not
refer to a USER message, or if the assistant message is not a direct child
of the user message.
"""
user_msg = db_session.get(ChatMessage, user_message_id)
if user_msg is None:
raise ValueError(f"User message {user_message_id} not found")
if user_msg.message_type != MessageType.USER:
raise ValueError(f"Message {user_message_id} is not a user message")
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
if assistant_msg is None:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} not found"
)
if assistant_msg.parent_message_id != user_message_id:
raise ValueError(
f"Assistant message {preferred_assistant_message_id} is not a child "
f"of user message {user_message_id}"
)
user_msg.preferred_response_id = preferred_assistant_message_id
user_msg.latest_child_message_id = preferred_assistant_message_id
db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -925,8 +839,6 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -3135,8 +3135,6 @@ class VoiceProvider(Base):
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

View File

@@ -17,39 +17,30 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
"""Fetch all voice providers."""
return list(
db_session.scalars(
select(VoiceProvider)
.where(VoiceProvider.deleted.is_(False))
.order_by(VoiceProvider.name)
).all()
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
)
def fetch_voice_provider_by_id(
db_session: Session, provider_id: int, include_deleted: bool = False
db_session: Session, provider_id: int
) -> VoiceProvider | None:
"""Fetch a voice provider by ID."""
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
if not include_deleted:
stmt = stmt.where(VoiceProvider.deleted.is_(False))
return db_session.scalar(stmt)
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.id == provider_id)
)
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default STT provider."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.is_default_stt.is_(True))
.where(VoiceProvider.deleted.is_(False))
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
)
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default TTS provider."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.is_default_tts.is_(True))
.where(VoiceProvider.deleted.is_(False))
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
)
@@ -58,9 +49,7 @@ def fetch_voice_provider_by_type(
) -> VoiceProvider | None:
"""Fetch a voice provider by type."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.provider_type == provider_type)
.where(VoiceProvider.deleted.is_(False))
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
)
@@ -119,10 +108,10 @@ def upsert_voice_provider(
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
"""Soft-delete a voice provider by ID."""
"""Delete a voice provider by ID."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider:
provider.deleted = True
db_session.delete(provider)
db_session.flush()

View File

@@ -25,6 +25,7 @@ class LlmProviderNames(str, Enum):
LM_STUDIO = "lm_studio"
MISTRAL = "mistral"
LITELLM_PROXY = "litellm_proxy"
BIFROST = "bifrost"
def __str__(self) -> str:
"""Needed so things like:
@@ -44,6 +45,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
]
@@ -61,6 +63,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: "Ollama",
LlmProviderNames.LM_STUDIO: "LM Studio",
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
LlmProviderNames.BIFROST: "Bifrost",
"groq": "Groq",
"anyscale": "Anyscale",
"deepseek": "DeepSeek",
@@ -112,6 +115,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
LlmProviderNames.VERTEX_AI,
LlmProviderNames.AZURE,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
}
# Model family name mappings for display name generation

View File

@@ -290,6 +290,17 @@ class LitellmLLM(LLM):
):
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
# Bifrost: OpenAI-compatible proxy that expects model names in
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
# We route through LiteLLM's openai provider with the Bifrost base URL,
# and ensure /v1 is appended.
if model_provider == LlmProviderNames.BIFROST:
self._custom_llm_provider = "openai"
if self._api_base is not None:
base = self._api_base.rstrip("/")
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
model_kwargs["api_base"] = self._api_base
# This is needed for Ollama to do proper function calling
if model_provider == LlmProviderNames.OLLAMA_CHAT and api_base is not None:
model_kwargs["api_base"] = api_base
@@ -401,14 +412,20 @@ class LitellmLLM(LLM):
optional_kwargs: dict[str, Any] = {}
# Model name
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
model_provider = (
f"{self.config.model_provider}/responses"
if is_openai_model # Uses litellm's completions -> responses bridge
else self.config.model_provider
)
model = (
f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
)
if is_bifrost:
# Bifrost expects model names in provider/model format
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
# so LiteLLM doesn't try to route based on the provider prefix.
model = self.config.deployment_name or self.config.model_name
else:
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
# Tool choice
if is_claude_model and tool_choice == ToolChoiceOptions.REQUIRED:
@@ -483,10 +500,11 @@ class LitellmLLM(LLM):
if structured_response_format:
optional_kwargs["response_format"] = structured_response_format
if not (is_claude_model or is_ollama or is_mistral):
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
# However, this param breaks Anthropic and Mistral models,
# so it must be conditionally included.
# so it must be conditionally included unless the request is
# routed through Bifrost's OpenAI-compatible endpoint.
# Additionally, tool_choice is not supported by Ollama and causes warnings if included.
# See also, https://github.com/ollama/ollama/issues/11171
optional_kwargs["allowed_openai_params"] = ["tool_choice"]

View File

@@ -8,24 +8,6 @@ from pydantic import BaseModel
class LLMOverride(BaseModel):
"""Per-request LLM settings that override persona defaults.
All fields are optional — only the fields that differ from the persona's
configured LLM need to be supplied. Used both over the wire (API requests)
and for multi-model comparison, where one override is supplied per model.
Attributes:
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
When ``None``, the persona's default provider is used.
model_version: Specific model version string (e.g. ``"gpt-4o"``).
When ``None``, the persona's default model is used.
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
persona's default temperature is used.
display_name: Human-readable label shown in the UI for this model,
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
when not set.
"""
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None

View File

@@ -13,6 +13,8 @@ LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
BIFROST_PROVIDER_NAME = "bifrost"
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,

View File

@@ -15,6 +15,7 @@ from onyx.llm.well_known_providers.auto_update_service import (
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
@@ -49,6 +50,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
}

View File

@@ -44,11 +44,12 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
# ---------------------------------------------------------------------------
@@ -122,9 +123,8 @@ def _validate_endpoint(
(not reachable — indicates the api_key is invalid).
Timeout handling:
- ConnectTimeout: TCP handshake never completed → cannot_connect.
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
(operator should consider increasing timeout_seconds).
- Any httpx.TimeoutException (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout) →
timeout (operator should consider increasing timeout_seconds).
- All other exceptions → cannot_connect.
"""
_check_ssrf_safety(endpoint_url)
@@ -141,19 +141,11 @@ def _validate_endpoint(
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
# Any timeout (connect, read, or write) means the configured timeout_seconds
# is too low for this endpoint. Report as timeout so the UI directs the user
# to increase the timeout setting.
logger.warning(
"Hook endpoint validation: read/write timeout for %s",
"Hook endpoint validation: timeout for %s",
endpoint_url,
exc_info=exc,
)

View File

@@ -57,6 +57,8 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import BifrostFinalModelResponse
from onyx.server.manage.llm.models import BifrostModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LitellmFinalModelResponse
from onyx.server.manage.llm.models import LitellmModelDetails
@@ -1422,11 +1424,26 @@ def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
cleaned_api_base = api_base.strip().rstrip("/")
url = f"{cleaned_api_base}/v1/models"
return _get_openai_compatible_models_response(
url=url,
source_name="LiteLLM proxy",
api_key=api_key,
)
def _get_openai_compatible_models_response(
url: str,
source_name: str,
api_key: str | None = None,
) -> dict:
"""Fetch model metadata from an OpenAI-compatible `/models` endpoint."""
headers = {
"Authorization": f"Bearer {api_key}",
"HTTP-Referer": "https://onyx.app",
"X-Title": "Onyx",
}
if not api_key:
headers.pop("Authorization")
try:
response = httpx.get(url, headers=headers, timeout=10.0)
@@ -1436,20 +1453,125 @@ def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
if e.response.status_code == 401:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
f"Authentication failed: invalid or missing API key for {source_name}.",
)
elif e.response.status_code == 404:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"LiteLLM models endpoint not found at {url}. Please verify the API base URL.",
f"{source_name} models endpoint not found at {url}. Please verify the API base URL.",
)
else:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch LiteLLM models: {e}",
f"Failed to fetch {source_name} models: {e}",
)
except Exception as e:
except httpx.RequestError as e:
logger.warning(
"Failed to fetch models from OpenAI-compatible endpoint",
extra={"source": source_name, "url": url, "error": str(e)},
exc_info=True,
)
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch LiteLLM models: {e}",
f"Failed to fetch {source_name} models: {e}",
)
except ValueError as e:
logger.warning(
"Received invalid model response from OpenAI-compatible endpoint",
extra={"source": source_name, "url": url, "error": str(e)},
exc_info=True,
)
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch {source_name} models: {e}",
)
@admin_router.post("/bifrost/available-models")
def get_bifrost_available_models(
request: BifrostModelsRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[BifrostFinalModelResponse]:
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
response_json = _get_bifrost_models_response(
api_base=request.api_base, api_key=request.api_key
)
models = response_json.get("data", [])
if not isinstance(models, list) or len(models) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your Bifrost endpoint",
)
results: list[BifrostFinalModelResponse] = []
for model in models:
try:
model_id = model.get("id", "")
model_name = model.get("name", model_id)
if not model_id:
continue
# Skip embedding models
if is_embedding_model(model_id):
continue
results.append(
BifrostFinalModelResponse(
name=model_id,
display_name=model_name,
max_input_tokens=model.get("context_length"),
supports_image_input=infer_vision_support(model_id),
supports_reasoning=is_reasoning_model(model_id, model_name),
)
)
except Exception as e:
logger.warning(
"Failed to parse Bifrost model entry",
extra={"error": str(e), "item": str(model)[:1000]},
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from Bifrost",
)
sorted_results = sorted(results, key=lambda m: m.name.lower())
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
for r in sorted_results
],
source_label="Bifrost",
)
return sorted_results
def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> dict:
"""Perform GET to Bifrost /v1/models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
# Ensure we hit /v1/models
if cleaned_api_base.endswith("/v1"):
url = f"{cleaned_api_base}/models"
else:
url = f"{cleaned_api_base}/v1/models"
return _get_openai_compatible_models_response(
url=url,
source_name="Bifrost",
api_key=api_key,
)

View File

@@ -449,3 +449,18 @@ class LitellmModelDetails(BaseModel):
class LitellmFinalModelResponse(BaseModel):
provider_name: str # Provider name (e.g. "openai")
model_name: str # Model ID (e.g. "gpt-4o")
# Bifrost dynamic models fetch
class BifrostModelsRequest(BaseModel):
api_base: str
api_key: str | None = None
provider_name: str | None = None # Optional: to save models to existing provider
class BifrostFinalModelResponse(BaseModel):
name: str # Model ID in provider/model format (e.g. "anthropic/claude-sonnet-4-6")
display_name: str # Human-readable name from Bifrost API
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool

View File

@@ -25,6 +25,7 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
LlmProviderNames.BEDROCK,
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.BIFROST,
}
)
@@ -50,6 +51,25 @@ BEDROCK_VISION_MODELS = frozenset(
}
)
# Known Bifrost/OpenAI-compatible vision-capable model families where the
# source API does not expose this metadata directly.
BIFROST_VISION_MODEL_FAMILIES = frozenset(
{
"anthropic/claude-3",
"anthropic/claude-4",
"amazon/nova-pro",
"amazon/nova-lite",
"amazon/nova-premier",
"openai/gpt-4o",
"openai/gpt-4.1",
"google/gemini",
"meta-llama/llama-3.2",
"mistral/pixtral",
"qwen/qwen2.5-vl",
"qwen/qwen-vl",
}
)
def is_valid_bedrock_model(
model_id: str,
@@ -76,11 +96,18 @@ def is_valid_bedrock_model(
def infer_vision_support(model_id: str) -> bool:
"""Infer vision support from model ID when base model metadata unavailable.
Used for cross-region inference profiles when the base model isn't
available in the user's region.
Used for providers like Bedrock and Bifrost where vision support may
need to be inferred from vendor/model naming conventions.
"""
model_id_lower = model_id.lower()
return any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS)
if any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS):
return True
normalized_model_id = model_id_lower.replace(".", "/")
return any(
vision_model in normalized_model_id
for vision_model in BIFROST_VISION_MODEL_FAMILIES
)
def generate_bedrock_display_name(model_id: str) -> str:
@@ -322,7 +349,7 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
- Ollama: "llama3:70b""Meta"
- Ollama: "qwen2.5:7b""Alibaba"
"""
if provider == LlmProviderNames.OPENROUTER:
if provider in (LlmProviderNames.OPENROUTER, LlmProviderNames.BIFROST):
# Format: "vendor/model-name" e.g., "anthropic/claude-3-5-sonnet"
if "/" in model_name:
vendor_key = model_name.split("/")[0].lower()

View File

@@ -449,40 +449,128 @@ class RedisHealthCollector(_CachedCollector):
return [memory_used, memory_peak, memory_frag, connected_clients]
class WorkerHealthCollector(_CachedCollector):
"""Collects Celery worker count and process count via inspect ping.
class WorkerHeartbeatMonitor:
"""Monitors Celery worker health via the event stream.
Uses a longer cache TTL (60s) since inspect.ping() is a broadcast
command that takes a couple seconds to complete.
Maintains a set of known worker short-names so that when a worker
stops responding, we emit ``up=0`` instead of silently dropping the
metric (which would make ``absent()``-style alerts impossible).
Subscribes to ``worker-heartbeat``, ``worker-online``, and
``worker-offline`` events via a single persistent connection.
Runs in a daemon thread started once during worker setup.
"""
# Remove a worker from _known_workers after this many consecutive
# missed pings (at 60s TTL ≈ 10 minutes of being unreachable).
_MAX_CONSECUTIVE_MISSES = 10
# Consider a worker down if no heartbeat received for this long.
_HEARTBEAT_TIMEOUT_SECONDS = 120.0
def __init__(self, cache_ttl: float = 60.0) -> None:
def __init__(self, celery_app: Any) -> None:
self._app = celery_app
self._worker_last_seen: dict[str, float] = {}
self._lock = threading.Lock()
self._running = False
self._thread: threading.Thread | None = None
def start(self) -> None:
"""Start the background event listener thread.
Safe to call multiple times — only starts one thread.
"""
if self._thread is not None and self._thread.is_alive():
return
self._running = True
self._thread = threading.Thread(target=self._listen, daemon=True)
self._thread.start()
logger.info("WorkerHeartbeatMonitor started")
def stop(self) -> None:
self._running = False
def _listen(self) -> None:
"""Background loop: connect to event stream and process heartbeats."""
while self._running:
try:
with self._app.connection() as conn:
recv = self._app.events.Receiver(
conn,
handlers={
"worker-heartbeat": self._on_heartbeat,
"worker-online": self._on_heartbeat,
"worker-offline": self._on_offline,
},
)
recv.capture(
limit=None, timeout=self._HEARTBEAT_TIMEOUT_SECONDS, wakeup=True
)
except Exception:
if self._running:
logger.debug(
"Heartbeat listener disconnected, reconnecting in 5s",
exc_info=True,
)
time.sleep(5.0)
else:
# capture() returned normally (timeout with no events); reconnect
if self._running:
logger.debug("Heartbeat capture timed out, reconnecting")
time.sleep(5.0)
def _on_heartbeat(self, event: dict[str, Any]) -> None:
hostname = event.get("hostname")
if hostname:
with self._lock:
self._worker_last_seen[hostname] = time.monotonic()
def _on_offline(self, event: dict[str, Any]) -> None:
hostname = event.get("hostname")
if hostname:
with self._lock:
self._worker_last_seen.pop(hostname, None)
def get_worker_status(self) -> dict[str, bool]:
"""Return {hostname: is_alive} for all known workers.
Thread-safe. Called by WorkerHealthCollector on each scrape.
Also prunes workers that have been dead longer than 2x the
heartbeat timeout to prevent unbounded growth.
"""
now = time.monotonic()
prune_threshold = self._HEARTBEAT_TIMEOUT_SECONDS * 2
with self._lock:
# Prune workers that have been gone for 2x the timeout
stale = [
h
for h, ts in self._worker_last_seen.items()
if (now - ts) > prune_threshold
]
for h in stale:
del self._worker_last_seen[h]
result: dict[str, bool] = {}
for hostname, last_seen in self._worker_last_seen.items():
alive = (now - last_seen) < self._HEARTBEAT_TIMEOUT_SECONDS
result[hostname] = alive
return result
class WorkerHealthCollector(_CachedCollector):
"""Collects Celery worker health from the heartbeat monitor.
Reads worker status from ``WorkerHeartbeatMonitor`` which listens
to the Celery event stream via a single persistent connection.
"""
def __init__(self, cache_ttl: float = 30.0) -> None:
super().__init__(cache_ttl)
self._celery_app: Any | None = None
# worker short-name → consecutive miss count.
# Workers start at 0 and reset to 0 each time they respond.
# Removed after _MAX_CONSECUTIVE_MISSES missed collects.
self._known_workers: dict[str, int] = {}
self._monitor: WorkerHeartbeatMonitor | None = None
def set_celery_app(self, app: Any) -> None:
"""Set the Celery app instance for inspect commands."""
self._celery_app = app
def set_monitor(self, monitor: WorkerHeartbeatMonitor) -> None:
"""Set the heartbeat monitor instance."""
self._monitor = monitor
def _collect_fresh(self) -> list[GaugeMetricFamily]:
if self._celery_app is None:
if self._monitor is None:
return []
active_workers = GaugeMetricFamily(
"onyx_celery_active_worker_count",
"Number of active Celery workers responding to ping",
"Number of active Celery workers with recent heartbeats",
)
worker_up = GaugeMetricFamily(
"onyx_celery_worker_up",
@@ -491,37 +579,15 @@ class WorkerHealthCollector(_CachedCollector):
)
try:
inspector = self._celery_app.control.inspect(timeout=3.0)
ping_result = inspector.ping()
status = self._monitor.get_worker_status()
alive_count = sum(1 for alive in status.values() if alive)
active_workers.add_metric([], alive_count)
responding: set[str] = set()
if ping_result:
active_workers.add_metric([], len(ping_result))
for worker_name in ping_result:
# Strip hostname suffix for cleaner labels
short_name = worker_name.split("@")[0]
responding.add(short_name)
else:
active_workers.add_metric([], 0)
# Register newly-seen workers and reset miss count for
# workers that responded.
for short_name in responding:
self._known_workers[short_name] = 0
# Increment miss count for non-responding workers and evict
# those that have been missing too long.
stale = []
for short_name in list(self._known_workers):
if short_name not in responding:
self._known_workers[short_name] += 1
if self._known_workers[short_name] >= self._MAX_CONSECUTIVE_MISSES:
stale.append(short_name)
for short_name in stale:
del self._known_workers[short_name]
for short_name in sorted(self._known_workers):
worker_up.add_metric([short_name], 1 if short_name in responding else 0)
for hostname in sorted(status):
# Use short name (before @) for single-host deployments,
# full hostname when multiple hosts share a worker type.
label = hostname.split("@")[0]
worker_up.add_metric([label], 1 if status[hostname] else 0)
except Exception:
logger.debug("Failed to collect worker health metrics", exc_info=True)

View File

@@ -15,6 +15,7 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
from onyx.server.metrics.indexing_pipeline import WorkerHeartbeatMonitor
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -28,6 +29,7 @@ _attempt_collector = IndexAttemptCollector()
_connector_collector = ConnectorHealthCollector()
_redis_health_collector = RedisHealthCollector()
_worker_health_collector = WorkerHealthCollector()
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
@@ -96,7 +98,16 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
redis_factory = _make_broker_redis_factory(celery_app)
_queue_collector.set_redis_factory(redis_factory)
_redis_health_collector.set_redis_factory(redis_factory)
_worker_health_collector.set_celery_app(celery_app)
# Start the heartbeat monitor daemon thread — uses a single persistent
# connection to receive worker-heartbeat events.
# Module-level singleton prevents duplicate threads on re-entry.
global _heartbeat_monitor
if _heartbeat_monitor is None:
_heartbeat_monitor = WorkerHeartbeatMonitor(celery_app)
_heartbeat_monitor.start()
_worker_health_collector.set_monitor(_heartbeat_monitor)
_attempt_collector.configure()
_connector_collector.configure()

View File

@@ -28,7 +28,6 @@ from onyx.chat.chat_utils import extract_headers
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_multi_model_stream
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
@@ -47,7 +46,6 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -62,8 +60,6 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -85,7 +81,6 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -575,46 +570,6 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
# Narrowed here; is_multi_model already checked llm_overrides is not None
llm_overrides = chat_message_req.llm_overrides or []
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in handle_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
if is_multi_model and not chat_message_req.stream:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -705,30 +660,6 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
request_body: SetPreferredResponseRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
"""Set the preferred assistant response for a multi-model turn."""
try:
# Ownership check: get_chat_message raises ValueError if the message
# doesn't belong to this user, preventing cross-user mutation.
get_chat_message(
chat_message_id=request_body.user_message_id,
user_id=user.id if user else None,
db_session=db_session,
)
set_preferred_response(
db_session=db_session,
user_message_id=request_body.user_message_id,
preferred_assistant_message_id=request_body.preferred_response_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -2,24 +2,11 @@ from pydantic import BaseModel
class Placement(BaseModel):
"""Coordinates that identify where a streaming packet belongs in the UI.
The frontend uses these fields to route each packet to the correct turn,
tool tab, agent sub-turn, and (in multi-model mode) response column.
Attributes:
turn_index: Monotonically increasing index of the iterative reasoning block
(e.g. tool call round) within this chat message. Lower values happened first.
tab_index: Disambiguates parallel tool calls within the same turn so each
tool's output can be displayed in its own tab.
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
top-level packets; an integer for tool-within-tool output.
model_index: Which model this packet belongs to in a multi-model comparison
(0, 1, or 2). ``None`` for single-model responses, preserving the
backwards-compatible wire format for existing API consumers.
"""
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
turn_index: int
# For parallel tool calls to preserve order of execution
tab_index: int = 0
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
model_index: int | None = None

View File

@@ -708,6 +708,7 @@ def run_research_agent_calls(
if __name__ == "__main__":
from queue import Queue
from uuid import uuid4
from onyx.chat.chat_state import ChatStateContainer
@@ -743,7 +744,8 @@ if __name__ == "__main__":
if user is None:
raise ValueError("No users found in database. Please create a user first.")
emitter = Emitter()
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
state_container = ChatStateContainer()
tool_dict = construct_tools(
@@ -790,4 +792,4 @@ if __name__ == "__main__":
print(result.intermediate_report)
print("=" * 80)
print(f"Citations: {result.citation_mapping}")
print(f"Total packets emitted: {emitter.bus.qsize()}")
print(f"Total packets emitted: {bus.qsize()}")

View File

@@ -103,6 +103,11 @@ _EXPECTED_CONFLUENCE_GROUPS = [
user_emails={"oauth@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="no yuhong allowed",
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
gives_anyone_access=False,
),
]

View File

@@ -1,253 +0,0 @@
"""Unit tests for the Emitter class.
Covers both modes (standalone and streaming) without any real database,
LLM, or queue infrastructure beyond the stdlib Queue.
"""
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _placement(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
)
def _packet(
turn_index: int = 0,
tab_index: int = 0,
sub_turn_index: int | None = None,
) -> Packet:
"""Build a minimal valid packet with an OverallStop payload."""
return Packet(
placement=_placement(turn_index, tab_index, sub_turn_index),
obj=OverallStop(stop_reason="test"),
)
# ---------------------------------------------------------------------------
# Standalone mode (no merged_queue)
# ---------------------------------------------------------------------------
class TestEmitterStandaloneMode:
def test_emitted_packet_arrives_on_bus(self) -> None:
emitter = Emitter()
pkt = _packet()
emitter.emit(pkt)
assert emitter.bus.get_nowait() is pkt
def test_bus_is_empty_before_emit(self) -> None:
emitter = Emitter()
assert emitter.bus.empty()
def test_multiple_packets_delivered_fifo(self) -> None:
emitter = Emitter()
p1 = _packet(turn_index=0)
p2 = _packet(turn_index=1)
emitter.emit(p1)
emitter.emit(p2)
assert emitter.bus.get_nowait() is p1
assert emitter.bus.get_nowait() is p2
def test_packet_not_modified(self) -> None:
"""Standalone mode must not wrap or mutate the packet."""
emitter = Emitter()
pkt = _packet(turn_index=7, tab_index=3)
emitter.emit(pkt)
retrieved = emitter.bus.get_nowait()
assert retrieved.placement.turn_index == 7
assert retrieved.placement.tab_index == 3
def test_get_default_emitter_is_standalone(self) -> None:
emitter = get_default_emitter()
pkt = _packet()
emitter.emit(pkt)
# Packet lands on the bus, not a shared queue
assert emitter.bus.get_nowait() is pkt
# ---------------------------------------------------------------------------
# Streaming mode (merged_queue provided)
# ---------------------------------------------------------------------------
class TestEmitterStreamingMode:
# --- Queue routing ---
def test_packet_goes_to_merged_queue_not_bus(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet())
assert not mq.empty()
assert emitter.bus.empty()
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=1, merged_queue=mq)
emitter.emit(_packet())
item = mq.get_nowait()
assert isinstance(item, tuple)
assert len(item) == 2
# --- model_index tagging ---
def test_model_idx_none_preserves_model_index_none(self) -> None:
"""N=1 backwards-compat: model_index must stay None in the packet."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=None, merged_queue=mq)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index is None
def test_model_idx_zero_tags_packet(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 0
def test_model_idx_one_tags_packet(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=1, merged_queue=mq)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 1
def test_model_idx_two_tags_packet(self) -> None:
"""Boundary: third model in a 3-model run."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=2, merged_queue=mq)
emitter.emit(_packet())
_key, tagged = mq.get_nowait()
assert tagged.placement.model_index == 2
# --- Queue key ---
def test_key_equals_model_idx_when_set(self) -> None:
"""Drain loop uses the key to route packets; it must match model_idx."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=2, merged_queue=mq)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 2
def test_key_is_zero_when_model_idx_none(self) -> None:
"""N=1: key defaults to 0 (single slot in the drain loop)."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=None, merged_queue=mq)
emitter.emit(_packet())
key, _ = mq.get_nowait()
assert key == 0
# --- Placement field preservation ---
def test_turn_index_is_preserved(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet(turn_index=5))
_, tagged = mq.get_nowait()
assert tagged.placement.turn_index == 5
def test_tab_index_is_preserved(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet(tab_index=3))
_, tagged = mq.get_nowait()
assert tagged.placement.tab_index == 3
def test_sub_turn_index_is_preserved(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet(sub_turn_index=2))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index == 2
def test_sub_turn_index_none_is_preserved(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet(sub_turn_index=None))
_, tagged = mq.get_nowait()
assert tagged.placement.sub_turn_index is None
def test_packet_obj_is_not_modified(self) -> None:
"""The payload object must survive tagging untouched."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
original_obj = OverallStop(stop_reason="sentinel")
pkt = Packet(placement=_placement(), obj=original_obj)
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert tagged.obj is original_obj
def test_different_obj_types_are_handled(self) -> None:
"""Any valid PacketObj type passes through correctly."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
pkt = Packet(placement=_placement(), obj=ReasoningStart())
emitter.emit(pkt)
_, tagged = mq.get_nowait()
assert isinstance(tagged.obj, ReasoningStart)
# --- bus is always created ---
def test_bus_exists_in_streaming_mode(self) -> None:
"""bus must always be present for backwards-compat with existing callers."""
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
assert hasattr(emitter, "bus")
assert isinstance(emitter.bus, queue.Queue)
def test_bus_stays_empty_in_streaming_mode(self) -> None:
import queue
mq: queue.Queue = queue.Queue()
emitter = Emitter(model_idx=0, merged_queue=mq)
emitter.emit(_packet())
assert emitter.bus.empty()

View File

@@ -1,636 +0,0 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in handle_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
import time
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from onyx.chat.models import StreamingError
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
defaults: dict[str, Any] = {
"message": "hello",
"chat_session_id": uuid4(),
}
defaults.update(kwargs)
return SendMessageRequest(**defaults)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
return LLMOverride(model_provider=provider, model_version=version)
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
"""Advance the generator one step to trigger early validation."""
from onyx.chat.process_message import handle_multi_model_stream
user = MagicMock()
user.is_anonymous = False
user.email = "test@example.com"
db = MagicMock()
gen = handle_multi_model_stream(req, user, db, overrides)
# Calling next() executes until the first yield OR raises.
# Validation errors are raised before any yield.
next(gen)
# ---------------------------------------------------------------------------
# handle_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
def test_single_override_raises(self) -> None:
"""Exactly 1 override is not multi-model — must raise."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
def test_four_overrides_raises(self) -> None:
"""4 overrides exceeds maximum — must raise."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(
req,
[
_make_override("openai", "gpt-4"),
_make_override("anthropic", "claude-3"),
_make_override("google", "gemini-pro"),
_make_override("cohere", "command-r"),
],
)
def test_zero_overrides_raises(self) -> None:
"""Empty override list raises."""
req = _make_request()
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [])
def test_deep_research_raises(self) -> None:
"""deep_research=True is incompatible with multi-model."""
req = _make_request(deep_research=True)
with pytest.raises(ValueError, match="not supported"):
_start_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
def test_exactly_two_overrides_is_minimum(self) -> None:
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
req = _make_request()
# 1 override must fail
with pytest.raises(ValueError, match="2-3"):
_start_stream(req, [_make_override()])
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
try:
_start_stream(
req, [_make_override(), _make_override("anthropic", "claude-3")]
)
except ValueError as exc:
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
except Exception:
pass # Any other error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
def test_user_message_not_found(self) -> None:
db = MagicMock()
db.get.return_value = None
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=999, preferred_assistant_message_id=1
)
def test_wrong_message_type(self) -> None:
"""Cannot set preferred response on a non-USER message."""
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.ASSISTANT # wrong type
db.get.return_value = user_msg
with pytest.raises(ValueError, match="not a user message"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_message_not_found(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
# First call returns user_msg, second call (for assistant) returns None
db.get.side_effect = [user_msg, None]
with pytest.raises(ValueError, match="not found"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_assistant_not_child_of_user(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 999 # different parent
db.get.side_effect = [user_msg, assistant_msg]
with pytest.raises(ValueError, match="not a child"):
set_preferred_response(
db, user_message_id=1, preferred_assistant_message_id=2
)
def test_valid_call_sets_preferred_response_id(self) -> None:
db = MagicMock()
user_msg = MagicMock()
user_msg.message_type = MessageType.USER
assistant_msg = MagicMock()
assistant_msg.parent_message_id = 1 # correct parent
db.get.side_effect = [user_msg, assistant_msg]
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
assert user_msg.preferred_response_id == 2
assert user_msg.latest_child_message_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
def test_display_name_defaults_none(self) -> None:
override = LLMOverride(model_provider="openai", model_version="gpt-4")
assert override.display_name is None
def test_display_name_set(self) -> None:
override = LLMOverride(
model_provider="openai",
model_version="gpt-4",
display_name="GPT-4 Turbo",
)
assert override.display_name == "GPT-4 Turbo"
def test_display_name_serializes(self) -> None:
override = LLMOverride(
model_provider="anthropic",
model_version="claude-opus-4-6",
display_name="Claude Opus",
)
d = override.model_dump()
assert d["display_name"] == "Claude Opus"
# ---------------------------------------------------------------------------
# _run_models — drain loop behaviour
# ---------------------------------------------------------------------------
def _make_setup(n_models: int = 1) -> MagicMock:
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
setup = MagicMock()
setup.llms = [MagicMock() for _ in range(n_models)]
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
setup.check_is_connected = MagicMock(return_value=True)
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
setup.reserved_token_count = 100
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
# constructors inside _run_model — must be typed correctly for Pydantic.
setup.new_msg_req.deep_research = False
setup.new_msg_req.internal_search_filters = None
setup.new_msg_req.allowed_tool_ids = None
setup.new_msg_req.include_citations = True
setup.search_params.project_id_filter = None
setup.search_params.persona_id_filter = None
setup.bypass_acl = False
setup.slack_context = None
setup.available_files.user_file_ids = []
setup.available_files.chat_file_ids = []
setup.forced_tool_id = None
setup.simple_chat_history = []
setup.chat_session.id = uuid4()
setup.user_message.id = None
setup.custom_tool_additional_headers = None
setup.mcp_headers = None
return setup
_RUN_MODELS_PATCHES = [
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch("onyx.chat.process_message.get_llm_token_counter", return_value=lambda _: 0),
]
def _run_models_collect(setup: MagicMock) -> list:
"""Drive _run_models to completion and return all yielded items."""
from onyx.chat.process_message import _run_models
return list(_run_models(setup, MagicMock(), MagicMock()))
class TestRunModels:
"""Tests for the _run_models worker-thread drain loop.
All external dependencies (LLM, DB, tools) are patched out. Worker threads
still run but return immediately since run_llm_loop is mocked.
"""
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
def emit_stop(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(
placement=Placement(turn_index=0),
obj=OverallStop(stop_reason="complete"),
)
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert len(stops) == 1
stop_obj = stops[0].obj
assert isinstance(stop_obj, OverallStop)
assert stop_obj.stop_reason == "complete"
def test_n1_emitted_packet_has_model_index_none(self) -> None:
"""Single-model path: model_index stays None for wire backwards-compat."""
def emit_one(**kwargs: Any) -> None:
kwargs["emitter"].emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index is None
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
def emit_one(**kwargs: Any) -> None:
# _model_idx is set by _run_model based on position in setup.llms
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=2))
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 2
indices = {p.placement.model_index for p in reasoning}
assert indices == {0, 1}
def test_model_error_yields_streaming_error(self) -> None:
"""An exception inside a worker thread is surfaced as a StreamingError."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("intentional test failure")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=1))
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
assert errors[0].error_code == "MODEL_ERROR"
assert "intentional test failure" in errors[0].error
def test_one_model_error_does_not_stop_other_models(self) -> None:
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
emitter = kwargs["emitter"]
# _model_idx is None for N=1, int for N>1
if emitter._model_idx == 0:
raise RuntimeError("model 0 failed")
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=fail_model_0_succeed_model_1,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(_make_setup(n_models=2))
errors = [p for p in packets if isinstance(p, StreamingError)]
assert len(errors) == 1
reasoning = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
]
assert len(reasoning) == 1
assert reasoning[0].placement.model_index == 1
def test_cancellation_yields_user_cancelled_stop(self) -> None:
"""If check_is_connected returns False, drain loop emits user_cancelled."""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
setup = _make_setup(n_models=1)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
packets = _run_models_collect(setup)
stops = [
p
for p in packets
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
]
assert any(
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
for s in stops
)
def test_completion_handle_called_on_disconnect(self) -> None:
"""llm_loop_completion_handle must still be called even when user disconnects.
Regression test for the disconnect-cleanup bug: the old
run_chat_loop_with_state_containers always called completion_callback in
its finally block (even on disconnect) so the DB message was updated from
the TERMINATED placeholder to a partial answer. The new _run_models must
replicate this — otherwise the integration test
test_send_message_disconnect_and_cleanup fails because the message stays
as "Response was terminated prior to completion, try regenerating."
"""
def slow_llm(**_kwargs: Any) -> None:
time.sleep(0.3)
setup = _make_setup(n_models=2)
setup.check_is_connected = MagicMock(return_value=False)
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
# Must be called once per model, not zero times
assert mock_handle.call_count == 2
def test_completion_handle_called_for_each_successful_model(self) -> None:
"""llm_loop_completion_handle must be called once per model that succeeded."""
setup = _make_setup(n_models=2)
with (
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(setup)
assert mock_handle.call_count == 2
def test_completion_handle_not_called_for_failed_model(self) -> None:
"""llm_loop_completion_handle must be skipped for a model that raised."""
def always_fail(**_kwargs: Any) -> None:
raise RuntimeError("fail")
with (
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
_run_models_collect(_make_setup(n_models=1))
mock_handle.assert_not_called()
def test_http_disconnect_completion_via_generator_exit(self) -> None:
"""GeneratorExit from HTTP disconnect triggers wait+completion in finally.
When the HTTP client closes the connection, Starlette throws GeneratorExit
into the stream generator, which propagates into _run_models. The finally
block must call executor.shutdown(wait=True) to wait for LLM threads to
finish, then persist their results via llm_loop_completion_handle.
This is the primary regression for test_send_message_disconnect_and_cleanup:
the integration test disconnects mid-stream and expects the DB message to be
updated from the TERMINATED placeholder to the real response.
"""
import threading
thread_completed = threading.Event()
def emit_then_complete(**kwargs: Any) -> None:
"""Emit one packet (to give generator a yield point), then finish."""
emitter = kwargs["emitter"]
emitter.emit(
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
)
# Small sleep so executor.shutdown(wait=True) in finally actually waits.
time.sleep(0.05)
thread_completed.set()
setup = _make_setup(n_models=1)
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
setup.check_is_connected = MagicMock(return_value=True)
with (
patch(
"onyx.chat.process_message.run_llm_loop",
side_effect=emit_then_complete,
),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch(
"onyx.chat.process_message.llm_loop_completion_handle"
) as mock_handle,
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
from onyx.chat.process_message import _run_models
gen = _run_models(setup, MagicMock(), MagicMock())
# Advance to the first yielded packet — generator suspends at `yield item`.
first = next(gen)
assert isinstance(first, Packet)
# Simulate Starlette closing the stream on HTTP client disconnect.
# GeneratorExit is thrown at the `yield item` suspension point.
gen.close()
# Finally block must have waited for the thread and saved completion.
assert (
thread_completed.is_set()
), "LLM thread must complete before gen.close() returns"
assert (
mock_handle.call_count == 1
), "completion handle must be called for the successful model"
def test_external_state_container_used_for_model_zero(self) -> None:
"""When provided, external_state_container is used as state_containers[0]."""
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.process_message import _run_models
external = ChatStateContainer()
setup = _make_setup(n_models=1)
with (
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch(
"onyx.chat.process_message.get_llm_token_counter",
return_value=lambda _: 0,
),
):
list(
_run_models(
setup, MagicMock(), MagicMock(), external_state_container=external
)
)
# The state_container kwarg passed to run_llm_loop must be the external one
call_kwargs = mock_llm.call_args.kwargs
assert call_kwargs["state_container"] is external

View File

@@ -0,0 +1,381 @@
"""Tests for Canvas connector — client (PR1)."""
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.error_handling.exceptions import OnyxError
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
FAKE_BASE_URL = "https://myschool.instructure.com"
FAKE_TOKEN = "fake-canvas-token"
def _mock_response(
status_code: int = 200,
json_data: Any = None,
link_header: str = "",
) -> MagicMock:
"""Create a mock HTTP response with status, json, and Link header."""
resp = MagicMock()
resp.status_code = status_code
resp.reason = "OK" if status_code < 300 else "Error"
resp.json.return_value = json_data if json_data is not None else []
resp.headers = {"Link": link_header}
return resp
# ---------------------------------------------------------------------------
# CanvasApiClient.__init__ tests
# ---------------------------------------------------------------------------
class TestCanvasApiClientInit:
def test_success(self) -> None:
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
expected_host = "myschool.instructure.com"
assert client.base_url == expected_base_url
assert client._expected_host == expected_host
def test_normalizes_trailing_slash(self) -> None:
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=f"{FAKE_BASE_URL}/",
)
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
assert client.base_url == expected_base_url
def test_normalizes_existing_api_v1(self) -> None:
client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=f"{FAKE_BASE_URL}/api/v1",
)
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
assert client.base_url == expected_base_url
def test_rejects_non_https_scheme(self) -> None:
with pytest.raises(ValueError, match="must use https"):
CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url="ftp://myschool.instructure.com",
)
def test_rejects_http(self) -> None:
with pytest.raises(ValueError, match="must use https"):
CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url="http://myschool.instructure.com",
)
def test_rejects_missing_host(self) -> None:
with pytest.raises(ValueError, match="must include a valid host"):
CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url="https://",
)
# ---------------------------------------------------------------------------
# CanvasApiClient._build_url tests
# ---------------------------------------------------------------------------
class TestBuildUrl:
def setup_method(self) -> None:
self.client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
def test_appends_endpoint(self) -> None:
result = self.client._build_url("courses")
expected = f"{FAKE_BASE_URL}/api/v1/courses"
assert result == expected
def test_strips_leading_slash_from_endpoint(self) -> None:
result = self.client._build_url("/courses")
expected = f"{FAKE_BASE_URL}/api/v1/courses"
assert result == expected
# ---------------------------------------------------------------------------
# CanvasApiClient._build_headers tests
# ---------------------------------------------------------------------------
class TestBuildHeaders:
def setup_method(self) -> None:
self.client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
def test_returns_bearer_auth(self) -> None:
result = self.client._build_headers()
expected = {"Authorization": f"Bearer {FAKE_TOKEN}"}
assert result == expected
# ---------------------------------------------------------------------------
# CanvasApiClient.get tests
# ---------------------------------------------------------------------------
class TestGet:
def setup_method(self) -> None:
self.client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url=FAKE_BASE_URL,
)
@patch("onyx.connectors.canvas.client.rl_requests")
def test_success_returns_json_and_next_url(self, mock_requests: MagicMock) -> None:
next_link = f"<{FAKE_BASE_URL}/api/v1/courses?page=2>; " 'rel="next"'
mock_requests.get.return_value = _mock_response(
json_data=[{"id": 1}], link_header=next_link
)
data, next_url = self.client.get("courses")
expected_data = [{"id": 1}]
expected_next = f"{FAKE_BASE_URL}/api/v1/courses?page=2"
assert data == expected_data
assert next_url == expected_next
@patch("onyx.connectors.canvas.client.rl_requests")
def test_success_no_next_page(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[{"id": 1}])
data, next_url = self.client.get("courses")
assert data == [{"id": 1}]
assert next_url is None
@patch("onyx.connectors.canvas.client.rl_requests")
def test_raises_on_error_status(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(403, {})
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
assert exc_info.value.status_code == 403
@patch("onyx.connectors.canvas.client.rl_requests")
def test_raises_on_404(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(404, {})
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
assert exc_info.value.status_code == 404
@patch("onyx.connectors.canvas.client.rl_requests")
def test_raises_on_429(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(429, {})
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
assert exc_info.value.status_code == 429
@patch("onyx.connectors.canvas.client.rl_requests")
def test_skips_params_when_using_full_url(self, mock_requests: MagicMock) -> None:
mock_requests.get.return_value = _mock_response(json_data=[])
full = f"{FAKE_BASE_URL}/api/v1/courses?page=2"
self.client.get(params={"per_page": "100"}, full_url=full)
_, kwargs = mock_requests.get.call_args
assert kwargs["params"] is None
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_extracts_message_from_error_dict(
self, mock_requests: MagicMock
) -> None:
"""Shape 1: {"error": {"message": "Not authorized"}}"""
mock_requests.get.return_value = _mock_response(
403, {"error": {"message": "Not authorized"}}
)
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Not authorized"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_extracts_message_from_error_string(
self, mock_requests: MagicMock
) -> None:
"""Shape 2: {"error": "Invalid access token"}"""
mock_requests.get.return_value = _mock_response(
401, {"error": "Invalid access token"}
)
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Invalid access token"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_extracts_message_from_errors_list(
self, mock_requests: MagicMock
) -> None:
"""Shape 3: {"errors": [{"message": "Invalid query"}]}"""
mock_requests.get.return_value = _mock_response(
400, {"errors": [{"message": "Invalid query"}]}
)
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Invalid query"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_dict_takes_priority_over_errors_list(
self, mock_requests: MagicMock
) -> None:
"""When both error shapes are present, error dict wins."""
mock_requests.get.return_value = _mock_response(
403, {"error": "Specific error", "errors": [{"message": "Generic"}]}
)
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Specific error"
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_error_falls_back_to_reason_when_no_json_message(
self, mock_requests: MagicMock
) -> None:
"""Empty error body falls back to response.reason."""
mock_requests.get.return_value = _mock_response(500, {})
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Error" # from _mock_response's reason for >= 300
assert result == expected
@patch("onyx.connectors.canvas.client.rl_requests")
def test_invalid_json_on_success_raises(self, mock_requests: MagicMock) -> None:
"""Invalid JSON on a 2xx response raises OnyxError."""
resp = MagicMock()
resp.status_code = 200
resp.json.side_effect = ValueError("No JSON")
resp.headers = {"Link": ""}
mock_requests.get.return_value = resp
with pytest.raises(OnyxError, match="Invalid JSON"):
self.client.get("courses")
@patch("onyx.connectors.canvas.client.rl_requests")
def test_invalid_json_on_error_falls_back_to_reason(
self, mock_requests: MagicMock
) -> None:
"""Invalid JSON on a 4xx response falls back to response.reason."""
resp = MagicMock()
resp.status_code = 500
resp.reason = "Internal Server Error"
resp.json.side_effect = ValueError("No JSON")
resp.headers = {"Link": ""}
mock_requests.get.return_value = resp
with pytest.raises(OnyxError) as exc_info:
self.client.get("courses")
result = exc_info.value.detail
expected = "Internal Server Error"
assert result == expected
# ---------------------------------------------------------------------------
# CanvasApiClient._parse_next_link tests
# ---------------------------------------------------------------------------
class TestParseNextLink:
def setup_method(self) -> None:
self.client = CanvasApiClient(
bearer_token=FAKE_TOKEN,
canvas_base_url="https://canvas.example.com",
)
def test_found(self) -> None:
header = '<https://canvas.example.com/api/v1/courses?page=2>; rel="next"'
result = self.client._parse_next_link(header)
expected = "https://canvas.example.com/api/v1/courses?page=2"
assert result == expected
def test_not_found(self) -> None:
header = '<https://canvas.example.com/api/v1/courses?page=1>; rel="current"'
result = self.client._parse_next_link(header)
assert result is None
def test_empty(self) -> None:
result = self.client._parse_next_link("")
assert result is None
def test_multiple_rels(self) -> None:
header = (
'<https://canvas.example.com/api/v1/courses?page=1>; rel="current", '
'<https://canvas.example.com/api/v1/courses?page=2>; rel="next"'
)
result = self.client._parse_next_link(header)
expected = "https://canvas.example.com/api/v1/courses?page=2"
assert result == expected
def test_rejects_host_mismatch(self) -> None:
header = '<https://evil.example.com/api/v1/courses?page=2>; rel="next"'
with pytest.raises(OnyxError, match="unexpected host"):
self.client._parse_next_link(header)
def test_rejects_non_https_link(self) -> None:
header = '<http://canvas.example.com/api/v1/courses?page=2>; rel="next"'
with pytest.raises(OnyxError, match="must use https"):
self.client._parse_next_link(header)

View File

@@ -0,0 +1,147 @@
from typing import Any
from unittest.mock import MagicMock
import pytest
import requests
from jira import JIRA
from jira.resources import Issue
from onyx.connectors.jira.connector import bulk_fetch_issues
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
return {
"id": issue_id,
"key": f"TEST-{issue_id}",
"fields": {"summary": f"Issue {issue_id}"},
}
def _mock_jira_client() -> MagicMock:
mock = MagicMock(spec=JIRA)
mock._options = {"server": "https://jira.example.com"}
mock._session = MagicMock()
mock._get_url = MagicMock(
return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
)
return mock
def test_bulk_fetch_success() -> None:
"""Happy path: all issues fetched in one request."""
client = _mock_jira_client()
raw = [_make_raw_issue("1"), _make_raw_issue("2"), _make_raw_issue("3")]
resp = MagicMock()
resp.json.return_value = {"issues": raw}
client._session.post.return_value = resp
result = bulk_fetch_issues(client, ["1", "2", "3"])
assert len(result) == 3
assert all(isinstance(r, Issue) for r in result)
client._session.post.assert_called_once()
def test_bulk_fetch_splits_on_json_error() -> None:
"""When the full batch fails with JSONDecodeError, sub-batches succeed."""
client = _mock_jira_client()
call_count = 0
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
nonlocal call_count
call_count += 1
ids = json["issueIdsOrKeys"]
if len(ids) > 2:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"Expecting ',' delimiter", "doc", 2294125
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
result = bulk_fetch_issues(client, ["1", "2", "3", "4"])
assert len(result) == 4
returned_ids = {r.raw["id"] for r in result}
assert returned_ids == {"1", "2", "3", "4"}
assert call_count > 1
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
"""A single issue that always fails JSON decode raises after splitting."""
client = _mock_jira_client()
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
ids = json["issueIdsOrKeys"]
if "bad" in ids:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"Expecting ',' delimiter", "doc", 100
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
with pytest.raises(requests.exceptions.JSONDecodeError):
bulk_fetch_issues(client, ["1", "bad", "2"])
def test_bulk_fetch_non_json_error_propagates() -> None:
"""Non-JSONDecodeError exceptions still propagate."""
client = _mock_jira_client()
resp = MagicMock()
resp.json.side_effect = ValueError("something else broke")
client._session.post.return_value = resp
try:
bulk_fetch_issues(client, ["1"])
assert False, "Expected ValueError to propagate"
except ValueError:
pass
def test_bulk_fetch_with_fields() -> None:
"""Fields parameter is forwarded correctly."""
client = _mock_jira_client()
raw = [_make_raw_issue("1")]
resp = MagicMock()
resp.json.return_value = {"issues": raw}
client._session.post.return_value = resp
bulk_fetch_issues(client, ["1"], fields="summary,description")
call_payload = client._session.post.call_args[1]["json"]
assert call_payload["fields"] == ["summary", "description"]
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
"""With a 6-issue batch where one is bad, recursion isolates it and raises."""
client = _mock_jira_client()
bad_id = "BAD"
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
ids = json["issueIdsOrKeys"]
if bad_id in ids:
resp = MagicMock()
resp.json.side_effect = requests.exceptions.JSONDecodeError(
"truncated", "doc", 999
)
return resp
resp = MagicMock()
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
return resp
client._session.post.side_effect = _post_side_effect
with pytest.raises(requests.exceptions.JSONDecodeError):
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])

View File

@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
class TestDeleteVoiceProvider:
"""Tests for delete_voice_provider."""
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = provider
delete_voice_provider(mock_db_session, 1)
assert provider.deleted is True
mock_db_session.delete.assert_called_once_with(provider)
mock_db_session.flush.assert_called_once()
def test_does_nothing_when_provider_not_found(

View File

@@ -1462,3 +1462,69 @@ def test_no_tool_choice_sent_when_no_tools(default_multi_llm: LitellmLLM) -> Non
assert (
"tool_choice" not in kwargs
), "tool_choice must not be sent to providers when no tools are provided"
def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
llm = LitellmLLM(
api_key="test_key",
api_base="https://bifrost.example.com/",
timeout=30,
model_provider=LlmProviderNames.BIFROST,
model_name="anthropic/claude-sonnet-4-6",
max_input_tokens=32000,
)
assert llm._custom_llm_provider == "openai"
assert llm._api_base == "https://bifrost.example.com/v1"
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
def test_bifrost_claude_includes_allowed_openai_params() -> None:
llm = LitellmLLM(
api_key="test_key",
api_base="https://bifrost.example.com",
timeout=30,
model_provider=LlmProviderNames.BIFROST,
model_name="anthropic/claude-sonnet-4-6",
max_input_tokens=32000,
)
messages: LanguageModelInput = [UserMessage(content="Use a tool if needed")]
tools = [
{
"type": "function",
"function": {
"name": "lookup",
"description": "Look up data",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
},
}
]
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Done"),
finish_reason="stop",
index=0,
)
],
model="anthropic/claude-sonnet-4-6",
),
]
with patch("litellm.completion") as mock_completion:
mock_completion.return_value = mock_stream_chunks
llm.invoke(messages, tools=tools)
kwargs = mock_completion.call_args.kwargs
assert kwargs["model"] == "anthropic/claude-sonnet-4-6"
assert kwargs["base_url"] == "https://bifrost.example.com/v1"
assert kwargs["custom_llm_provider"] == "openai"
assert kwargs["allowed_openai_params"] == ["tool_choice"]

View File

@@ -3,7 +3,7 @@
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
- _validate_endpoint: httpx exception → HookValidateStatus mapping
ConnectTimeout → cannot_connect (TCP handshake never completed)
ConnectTimeout → timeout (any timeout directs user to increase timeout_seconds)
ConnectError → cannot_connect (DNS / TLS failure)
ReadTimeout et al. → timeout (TCP connected, server slow)
Any other exc → cannot_connect
@@ -61,7 +61,7 @@ class TestCheckSsrfSafety:
def test_non_https_scheme_rejected(self, url: str) -> None:
with pytest.raises(OnyxError) as exc_info:
self._call(url)
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert "https" in (exc_info.value.detail or "").lower()
# --- private IP blocklist ---
@@ -87,7 +87,7 @@ class TestCheckSsrfSafety:
):
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
self._call("https://internal.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
assert ip in (exc_info.value.detail or "")
def test_public_ip_is_allowed(self) -> None:
@@ -106,7 +106,7 @@ class TestCheckSsrfSafety:
pytest.raises(OnyxError) as exc_info,
):
self._call("https://no-such-host.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
# ---------------------------------------------------------------------------
@@ -158,13 +158,11 @@ class TestValidateEndpoint:
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.cannot_connect
assert self._call().status == HookValidateStatus.timeout
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(

View File

@@ -12,6 +12,8 @@ import httpx
import pytest
from onyx.error_handling.exceptions import OnyxError
from onyx.server.manage.llm.models import BifrostFinalModelResponse
from onyx.server.manage.llm.models import BifrostModelsRequest
from onyx.server.manage.llm.models import LitellmFinalModelResponse
from onyx.server.manage.llm.models import LitellmModelsRequest
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
@@ -850,13 +852,15 @@ class TestGetLitellmAvailableModels:
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_get.side_effect = Exception("Connection refused")
mock_get.side_effect = httpx.ConnectError(
"Connection refused", request=MagicMock()
)
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM models"):
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM proxy models"):
get_litellm_available_models(request, MagicMock(), mock_session)
def test_401_raises_authentication_error(self) -> None:
@@ -898,3 +902,113 @@ class TestGetLitellmAvailableModels:
)
with pytest.raises(OnyxError, match="endpoint not found"):
get_litellm_available_models(request, MagicMock(), mock_session)
class TestGetBifrostAvailableModels:
"""Tests for the Bifrost model fetch endpoint."""
@pytest.fixture
def mock_bifrost_response(self) -> dict:
"""Mock response from Bifrost /v1/models endpoint."""
return {
"data": [
{
"id": "anthropic/claude-3-5-sonnet",
"name": "Claude 3.5 Sonnet",
"context_length": 200000,
},
{
"id": "openai/gpt-4o",
"name": "GPT-4o",
"context_length": 128000,
},
{
"id": "deepseek/deepseek-r1",
"name": "DeepSeek R1",
"context_length": 64000,
},
]
}
def test_returns_model_list(self, mock_bifrost_response: dict) -> None:
"""Test that endpoint returns properly formatted non-embedding models."""
from onyx.server.manage.llm.api import get_bifrost_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_bifrost_response
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = BifrostModelsRequest(api_base="https://bifrost.example.com")
results = get_bifrost_available_models(request, MagicMock(), mock_session)
assert len(results) == 3
assert all(isinstance(r, BifrostFinalModelResponse) for r in results)
assert [r.name for r in results] == sorted(
[r.name for r in results], key=str.lower
)
def test_infers_vision_support(self, mock_bifrost_response: dict) -> None:
"""Test that vision support is inferred from provider/model IDs."""
from onyx.server.manage.llm.api import get_bifrost_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_bifrost_response
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = BifrostModelsRequest(api_base="https://bifrost.example.com")
results = get_bifrost_available_models(request, MagicMock(), mock_session)
claude = next(r for r in results if r.name == "anthropic/claude-3-5-sonnet")
gpt4o = next(r for r in results if r.name == "openai/gpt-4o")
deepseek = next(r for r in results if r.name == "deepseek/deepseek-r1")
assert claude.supports_image_input is True
assert gpt4o.supports_image_input is True
assert deepseek.supports_image_input is False
def test_existing_v1_suffix_is_not_duplicated(self) -> None:
"""Test that an existing /v1 suffix still hits a single /v1/models endpoint."""
from onyx.server.manage.llm.api import get_bifrost_available_models
mock_session = MagicMock()
response = {"data": [{"id": "openai/gpt-4o", "name": "GPT-4o"}]}
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = response
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = BifrostModelsRequest(api_base="https://bifrost.example.com/v1")
get_bifrost_available_models(request, MagicMock(), mock_session)
called_url = mock_get.call_args[0][0]
assert called_url == "https://bifrost.example.com/v1/models"
def test_request_failure_is_logged_and_wrapped(self) -> None:
"""Test that request-layer failures are logged before raising OnyxError."""
from onyx.server.manage.llm.api import get_bifrost_available_models
mock_session = MagicMock()
with (
patch("onyx.server.manage.llm.api.httpx.get") as mock_get,
patch("onyx.server.manage.llm.api.logger.warning") as mock_warning,
):
mock_get.side_effect = httpx.ConnectError(
"Connection refused", request=MagicMock()
)
request = BifrostModelsRequest(api_base="https://bifrost.example.com")
with pytest.raises(OnyxError, match="Failed to fetch Bifrost models"):
get_bifrost_available_models(request, MagicMock(), mock_session)
mock_warning.assert_called_once()

View File

@@ -176,6 +176,14 @@ class TestInferVisionSupport:
"""Test Nova Pro has vision."""
assert infer_vision_support("amazon.nova-pro-v1") is True
def test_bifrost_claude_has_vision(self) -> None:
"""Test Bifrost Claude models are recognized as vision-capable."""
assert infer_vision_support("anthropic/claude-3-5-sonnet") is True
def test_bifrost_gpt4o_has_vision(self) -> None:
"""Test Bifrost GPT-4o models are recognized as vision-capable."""
assert infer_vision_support("openai/gpt-4o") is True
def test_mistral_no_vision(self) -> None:
"""Test Mistral doesn't have vision (not in known list)."""
assert infer_vision_support("mistral.mistral-large") is False

View File

@@ -0,0 +1,170 @@
"""Tests for WorkerHeartbeatMonitor and WorkerHealthCollector."""
import time
from unittest.mock import MagicMock
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
from onyx.server.metrics.indexing_pipeline import WorkerHeartbeatMonitor
class TestWorkerHeartbeatMonitor:
def test_heartbeat_registers_worker(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
monitor._on_heartbeat({"hostname": "primary@host1"})
status = monitor.get_worker_status()
assert "primary@host1" in status
assert status["primary@host1"] is True
def test_multiple_workers(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
monitor._on_heartbeat({"hostname": "primary@host1"})
monitor._on_heartbeat({"hostname": "docfetching@host1"})
monitor._on_heartbeat({"hostname": "monitoring@host1"})
status = monitor.get_worker_status()
assert len(status) == 3
assert all(alive for alive in status.values())
def test_offline_removes_worker(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
monitor._on_heartbeat({"hostname": "primary@host1"})
monitor._on_offline({"hostname": "primary@host1"})
status = monitor.get_worker_status()
assert "primary@host1" not in status
def test_stale_heartbeat_marks_worker_down(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
with monitor._lock:
monitor._worker_last_seen["primary@host1"] = (
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS - 10
)
status = monitor.get_worker_status()
assert status["primary@host1"] is False
def test_very_stale_worker_is_pruned(self) -> None:
"""Workers dead for 2x the timeout are pruned from the dict."""
monitor = WorkerHeartbeatMonitor(MagicMock())
with monitor._lock:
monitor._worker_last_seen["gone@host1"] = (
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS * 2 - 10
)
status = monitor.get_worker_status()
assert "gone@host1" not in status
assert monitor.get_worker_status() == {}
def test_heartbeat_refreshes_stale_worker(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
with monitor._lock:
monitor._worker_last_seen["primary@host1"] = (
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS - 10
)
assert monitor.get_worker_status()["primary@host1"] is False
monitor._on_heartbeat({"hostname": "primary@host1"})
assert monitor.get_worker_status()["primary@host1"] is True
def test_ignores_empty_hostname(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
monitor._on_heartbeat({})
monitor._on_heartbeat({"hostname": ""})
monitor._on_offline({})
assert monitor.get_worker_status() == {}
def test_returns_full_hostname_as_key(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
monitor._on_heartbeat({"hostname": "docprocessing@my-long-host.local"})
status = monitor.get_worker_status()
assert "docprocessing@my-long-host.local" in status
def test_start_is_idempotent(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
# Mock the thread so we don't actually start one
mock_thread = MagicMock()
mock_thread.is_alive.return_value = True
monitor._thread = mock_thread
monitor._running = True
# Second start should be a no-op
monitor.start()
# Thread constructor should not have been called again
assert monitor._thread is mock_thread
def test_thread_safety(self) -> None:
"""get_worker_status should not raise even if heartbeats arrive concurrently."""
monitor = WorkerHeartbeatMonitor(MagicMock())
monitor._on_heartbeat({"hostname": "primary@host1"})
status = monitor.get_worker_status()
monitor._on_heartbeat({"hostname": "primary@host1"})
status2 = monitor.get_worker_status()
assert status == status2
class TestWorkerHealthCollector:
def test_returns_empty_when_no_monitor(self) -> None:
collector = WorkerHealthCollector(cache_ttl=0)
assert collector.collect() == []
def test_collects_active_workers(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
monitor._on_heartbeat({"hostname": "primary@host1"})
monitor._on_heartbeat({"hostname": "docfetching@host1"})
monitor._on_heartbeat({"hostname": "monitoring@host1"})
collector = WorkerHealthCollector(cache_ttl=0)
collector.set_monitor(monitor)
families = collector.collect()
assert len(families) == 2
active = families[0]
assert active.name == "onyx_celery_active_worker_count"
assert active.samples[0].value == 3
up = families[1]
assert up.name == "onyx_celery_worker_up"
assert len(up.samples) == 3
# Labels use short names (before @)
labels = {s.labels["worker"] for s in up.samples}
assert labels == {"primary", "docfetching", "monitoring"}
for sample in up.samples:
assert sample.value == 1
def test_reports_dead_worker(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
monitor._on_heartbeat({"hostname": "primary@host1"})
with monitor._lock:
monitor._worker_last_seen["monitoring@host1"] = (
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS - 10
)
collector = WorkerHealthCollector(cache_ttl=0)
collector.set_monitor(monitor)
families = collector.collect()
active = families[0]
assert active.samples[0].value == 1
up = families[1]
samples_by_name = {s.labels["worker"]: s.value for s in up.samples}
assert samples_by_name["primary"] == 1
assert samples_by_name["monitoring"] == 0
def test_empty_monitor_returns_zero(self) -> None:
monitor = WorkerHeartbeatMonitor(MagicMock())
collector = WorkerHealthCollector(cache_ttl=0)
collector.set_monitor(monitor)
families = collector.collect()
assert len(families) == 2
active = families[0]
assert active.samples[0].value == 0
up = families[1]
assert up.name == "onyx_celery_worker_up"
assert len(up.samples) == 0

View File

@@ -1,5 +1,6 @@
"""Tests for memory tool streaming packet emissions."""
from queue import Queue
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -18,7 +19,8 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
@pytest.fixture
def emitter() -> Emitter:
return Emitter()
bus: Queue = Queue()
return Emitter(bus)
@pytest.fixture

View File

@@ -1,128 +0,0 @@
{
"labels": [],
"comment": "",
"fixWithAI": true,
"hideFooter": false,
"strictness": 2,
"statusCheck": true,
"commentTypes": [
"logic",
"syntax",
"style"
],
"instructions": "",
"disabledLabels": [],
"excludeAuthors": [
"dependabot[bot]",
"renovate[bot]"
],
"ignoreKeywords": "",
"ignorePatterns": "greptile.json\n",
"includeAuthors": [],
"summarySection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"excludeBranches": [],
"fileChangeLimit": 300,
"includeBranches": [],
"includeKeywords": "",
"triggerOnUpdates": true,
"updateExistingSummaryComment": true,
"updateSummaryOnly": false,
"issuesTableSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"statusCommentsEnabled": true,
"confidenceScoreSection": {
"included": true,
"collapsible": false
},
"sequenceDiagramSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"shouldUpdateDescription": false,
"customContext": {
"other": [
{
"scope": [],
"content": "Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code."
},
{
"scope": [],
"content": "Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly."
}
],
"rules": [
{
"scope": [],
"rule": "Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of TODO(name): ... or TODO(1234): ..."
},
{
"scope": ["web/**"],
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/AGENTS.md file."
},
{
"scope": [],
"rule": "Remove temporary debugging code before merging to production, especially tenant-specific debugging logs."
},
{
"scope": [],
"rule": "When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant."
},
{
"scope": [],
"rule": "Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies."
},
{
"scope": [],
"rule": "Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments."
},
{
"scope": ["web/**"],
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
},
{
"scope": ["web/**"],
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
}
],
"files": [
{
"scope": [],
"path": "contributing_guides/best_practices.md",
"description": "Best practices for contributing to the codebase"
},
{
"scope": [],
"path": "CLAUDE.md",
"description": "Project instructions and coding standards"
},
{
"scope": [],
"path": "backend/alembic/README.md",
"description": "Migration guidance, including multi-tenant migration behavior"
},
{
"scope": [],
"path": "deployment/helm/charts/onyx/values-lite.yaml",
"description": "Lite deployment Helm values and service assumptions"
},
{
"scope": [],
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
}
]
}
}

View File

@@ -342,9 +342,9 @@ visible text in the DOM (e.g., `title`, `description`, `label`) should be typed
```typescript
import type { RichStr } from "@opal/types";
import { resolveStr } from "@opal/components/text/InlineMarkdown";
import { Text } from "@opal/components";
// ✅ Good — new components accept string | RichStr
// ✅ Good — new components accept string | RichStr and render via Text
interface InfoCardProps {
title: string | RichStr;
description?: string | RichStr;
@@ -353,9 +353,9 @@ interface InfoCardProps {
function InfoCard({ title, description }: InfoCardProps) {
return (
<div>
<Text font="main-ui-action">{resolveStr(title)}</Text>
<Text font="main-ui-action">{title}</Text>
{description && (
<Text font="secondary-body" color="text-03">{resolveStr(description)}</Text>
<Text font="secondary-body" color="text-03">{description}</Text>
)}
</div>
);

View File

@@ -4,11 +4,15 @@ import {
Interactive,
type InteractiveStatelessProps,
} from "@opal/core";
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
import type {
ContainerSizeVariants,
ExtremaSizeVariants,
RichStr,
} from "@opal/types";
import { Text } from "@opal/components";
import type { TooltipSide } from "@opal/components";
import type { IconFunctionComponent } from "@opal/types";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
import { cn } from "@opal/utils";
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
// ---------------------------------------------------------------------------
@@ -18,13 +22,13 @@ import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
type ButtonContentProps =
| {
icon?: IconFunctionComponent;
children: string;
children: string | RichStr;
rightIcon?: IconFunctionComponent;
responsiveHideText?: never;
}
| {
icon: IconFunctionComponent;
children?: string;
children?: string | RichStr;
rightIcon?: IconFunctionComponent;
responsiveHideText?: boolean;
};
@@ -69,15 +73,24 @@ function Button({
const isLarge = size === "lg";
const labelEl = children ? (
<span
className={cn(
"whitespace-nowrap",
isLarge ? "font-main-ui-body " : "font-secondary-body",
responsiveHideText && "hidden md:inline"
)}
>
{children}
</span>
responsiveHideText ? (
<span className="hidden md:inline whitespace-nowrap">
<Text
font={isLarge ? "main-ui-body" : "secondary-body"}
color="inherit"
>
{children}
</Text>
</span>
) : (
<Text
font={isLarge ? "main-ui-body" : "secondary-body"}
color="inherit"
nowrap
>
{children}
</Text>
)
) : null;
const button = (

View File

@@ -4,7 +4,8 @@ import {
type InteractiveStatefulProps,
} from "@opal/core";
import type { TooltipSide } from "@opal/components";
import type { IconFunctionComponent } from "@opal/types";
import type { IconFunctionComponent, RichStr } from "@opal/types";
import { Text } from "@opal/components";
import { SvgX } from "@opal/icons";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
@@ -16,12 +17,12 @@ import { Button } from "@opal/components/buttons/button/components";
// ---------------------------------------------------------------------------
interface FilterButtonProps
extends Omit<InteractiveStatefulProps, "variant" | "state"> {
extends Omit<InteractiveStatefulProps, "variant" | "state" | "children"> {
/** Left icon — always visible. */
icon: IconFunctionComponent;
/** Label text between icon and trailing indicator. */
children: string;
children: string | RichStr;
/** Whether the filter has an active selection. @default false */
active?: boolean;
@@ -68,9 +69,9 @@ function FilterButton({
<Interactive.Container type="button">
<div className="interactive-foreground flex flex-row items-center gap-1">
{iconWrapper(Icon, "lg", true)}
<span className="whitespace-nowrap font-main-ui-action">
<Text font="main-ui-action" color="inherit" nowrap>
{children}
</span>
</Text>
<div style={{ visibility: active ? "hidden" : "visible" }}>
{iconWrapper(ChevronIcon, "lg", true)}
</div>

View File

@@ -4,7 +4,12 @@ import {
type InteractiveStatefulProps,
type InteractiveStatefulInteraction,
} from "@opal/core";
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
import type {
ContainerSizeVariants,
ExtremaSizeVariants,
RichStr,
} from "@opal/types";
import { Text } from "@opal/components";
import type { InteractiveContainerRoundingVariant } from "@opal/core";
import type { TooltipSide } from "@opal/components";
import type { IconFunctionComponent } from "@opal/types";
@@ -28,17 +33,17 @@ type OpenButtonContentProps =
| {
foldable: true;
icon: IconFunctionComponent;
children: string;
children: string | RichStr;
}
| {
foldable?: false;
icon?: IconFunctionComponent;
children: string;
children: string | RichStr;
}
| {
foldable?: false;
icon: IconFunctionComponent;
children?: string;
children?: string | RichStr;
};
type OpenButtonVariant = "select-light" | "select-heavy" | "select-tinted";
@@ -101,14 +106,13 @@ function OpenButton({
const isLarge = size === "lg";
const labelEl = children ? (
<span
className={cn(
"whitespace-nowrap",
isLarge ? "font-main-ui-body" : "font-secondary-body"
)}
<Text
font={isLarge ? "main-ui-body" : "secondary-body"}
color="inherit"
nowrap
>
{children}
</span>
</Text>
) : null;
const button = (
@@ -177,7 +181,7 @@ function OpenButton({
side={tooltipSide}
sideOffset={4}
>
{resolvedTooltip}
<Text>{resolvedTooltip}</Text>
</TooltipPrimitive.Content>
</TooltipPrimitive.Portal>
</TooltipPrimitive.Root>

View File

@@ -4,7 +4,12 @@ import {
useDisabled,
type InteractiveStatefulProps,
} from "@opal/core";
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
import type {
ContainerSizeVariants,
ExtremaSizeVariants,
RichStr,
} from "@opal/types";
import { Text } from "@opal/components";
import type { TooltipSide } from "@opal/components";
import type { IconFunctionComponent } from "@opal/types";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
@@ -26,19 +31,19 @@ type SelectButtonContentProps =
| {
foldable: true;
icon: IconFunctionComponent;
children: string;
children: string | RichStr;
rightIcon?: IconFunctionComponent;
}
| {
foldable?: false;
icon?: IconFunctionComponent;
children: string;
children: string | RichStr;
rightIcon?: IconFunctionComponent;
}
| {
foldable?: false;
icon: IconFunctionComponent;
children?: string;
children?: string | RichStr;
rightIcon?: IconFunctionComponent;
};
@@ -79,13 +84,10 @@ function SelectButton({
const isLarge = size === "lg";
const labelEl = children ? (
<span
className={cn(
"opal-select-button-label",
isLarge ? "font-main-ui-body" : "font-secondary-body"
)}
>
{children}
<span className="opal-select-button-label">
<Text font={isLarge ? "main-ui-body" : "secondary-body"} color="inherit">
{children}
</Text>
</span>
) : null;
@@ -137,7 +139,7 @@ function SelectButton({
side={tooltipSide}
sideOffset={4}
>
{resolvedTooltip}
<Text>{resolvedTooltip}</Text>
</TooltipPrimitive.Content>
</TooltipPrimitive.Portal>
</TooltipPrimitive.Root>

View File

@@ -4,7 +4,9 @@ import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import { SvgArrowRight, SvgChevronLeft, SvgChevronRight } from "@opal/icons";
import { containerSizeVariants } from "@opal/shared";
import type { WithoutStyles } from "@opal/types";
import type { RichStr, WithoutStyles } from "@opal/types";
import { Text } from "@opal/components";
import { toPlainString } from "@opal/components/text/InlineMarkdown";
import { cn } from "@opal/utils";
import * as PopoverPrimitive from "@radix-ui/react-popover";
import {
@@ -38,7 +40,7 @@ interface SimplePaginationProps
/** Hides the `currentPage/totalPages` summary text between arrows. Default: `false`. */
hidePages?: boolean;
/** Unit label shown after the summary (e.g. `"pages"`). Always has 4px spacing. */
units?: string;
units?: string | RichStr;
}
/**
@@ -63,7 +65,7 @@ interface CountPaginationProps
/** Hides the current page number between the arrows. Default: `false`. */
hidePages?: boolean;
/** Unit label shown after the total count (e.g. `"items"`). Always has 4px spacing. */
units?: string;
units?: string | RichStr;
}
/**
@@ -331,7 +333,9 @@ function PaginationSimple({
}: SimplePaginationProps) {
const handleChange = (page: number) => onChange?.(page);
const label = `${currentPage}/${totalPages}${units ? ` ${units}` : ""}`;
const label = `${currentPage}/${totalPages}${
units ? ` ${toPlainString(units)}` : ""
}`;
return (
<div {...props} className="flex items-center">
@@ -385,7 +389,16 @@ function PaginationCount({
{rangeStart}~{rangeEnd}
<span className={textClasses(size, "muted")}>of</span>
{totalItems}
{units && <span className="ml-1">{units}</span>}
{units && (
<span className="ml-1">
<Text
color="inherit"
font={size === "sm" ? "secondary-body" : "main-ui-muted"}
>
{units}
</Text>
</span>
)}
</span>
{/* Buttons: < [page] > */}

View File

@@ -1,6 +1,7 @@
import "@opal/components/tag/styles.css";
import type { IconFunctionComponent } from "@opal/types";
import type { IconFunctionComponent, RichStr } from "@opal/types";
import { Text } from "@opal/components";
import { cn } from "@opal/utils";
// ---------------------------------------------------------------------------
@@ -16,7 +17,7 @@ interface TagProps {
icon?: IconFunctionComponent;
/** Tag label text. */
title: string;
title: string | RichStr;
/** Color variant. Default: `"gray"`. */
color?: TagColor;
@@ -51,14 +52,13 @@ function Tag({ icon: Icon, title, color = "gray", size = "sm" }: TagProps) {
<Icon className={cn("opal-auxiliary-tag-icon", config.text)} />
</div>
)}
<span
className={cn(
"opal-auxiliary-tag-title px-[2px]",
size === "md" ? "font-secondary-body" : "font-figure-small-value",
config.text
)}
>
{title}
<span className={cn("opal-auxiliary-tag-title px-[2px]", config.text)}>
<Text
font={size === "md" ? "secondary-body" : "figure-small-value"}
color="inherit"
>
{title}
</Text>
</span>
</div>
);

View File

@@ -10,7 +10,7 @@ import type { RichStr } from "@opal/types";
const SAFE_PROTOCOL = /^https?:|^mailto:|^tel:/i;
const ALLOWED_ELEMENTS = ["p", "a", "strong", "em", "code", "del"];
const ALLOWED_ELEMENTS = ["p", "br", "a", "strong", "em", "code", "del"];
const INLINE_COMPONENTS = {
p: ({ children }: { children?: ReactNode }) => <>{children}</>,
@@ -41,6 +41,11 @@ interface InlineMarkdownProps {
}
export default function InlineMarkdown({ content }: InlineMarkdownProps) {
// Convert \n to CommonMark hard line breaks (two trailing spaces + newline).
// react-markdown renders these as <br />, which inherits the parent's
// line-height for font-appropriate spacing.
const normalized = content.replace(/\n/g, " \n");
return (
<ReactMarkdown
components={INLINE_COMPONENTS}
@@ -48,7 +53,7 @@ export default function InlineMarkdown({ content }: InlineMarkdownProps) {
unwrapDisallowed
remarkPlugins={[remarkGfm]}
>
{content}
{normalized}
</ReactMarkdown>
);
}

View File

@@ -90,15 +90,15 @@ import { markdown } from "@opal/utils";
</Text>
```
Supported syntax: `**bold**`, `*italic*`, `` `code` ``, `[link](url)`, `~~strikethrough~~`.
Supported syntax: `**bold**`, `*italic*`, `` `code` ``, `[link](url)`, `~~strikethrough~~`, `\n` (newline → `<br />`).
Markdown rendering uses `react-markdown` internally, restricted to inline elements only.
`http(s)` links open in a new tab; `mailto:` and `tel:` links open natively. Inline code
inherits the parent font size and switches to the monospace family.
**Note:** This is inline-only markdown. Multi-paragraph content (`"Hello\n\nWorld"`) will
collapse into a single run of text since paragraph wrappers are stripped. For block-level
markdown, use `MinimalMarkdown` instead.
Newlines (`\n`) are converted to `<br />` elements that inherit the parent's line-height,
so line spacing is proportional to the font size. For full block-level markdown (code blocks,
headings, lists), use `MinimalMarkdown` instead.
### Using `RichStr` in component props

View File

@@ -29,6 +29,7 @@ type TextFont =
| "figure-keystroke";
type TextColor =
| "inherit"
| "text-01"
| "text-02"
| "text-03"
@@ -60,6 +61,9 @@ interface TextProps
/** Prevent text wrapping. */
nowrap?: boolean;
/** Truncate text to N lines with ellipsis. `1` uses simple truncation; `2+` uses `-webkit-line-clamp`. */
maxLines?: number;
/** Plain string or `markdown()` for inline markdown. */
children?: string | RichStr;
}
@@ -89,7 +93,8 @@ const FONT_CONFIG: Record<TextFont, string> = {
"figure-keystroke": "font-figure-keystroke",
};
const COLOR_CONFIG: Record<TextColor, string> = {
const COLOR_CONFIG: Record<TextColor, string | null> = {
inherit: null,
"text-01": "text-text-01",
"text-02": "text-text-02",
"text-03": "text-text-03",
@@ -115,17 +120,29 @@ function Text({
color = "text-04",
as: Tag = "span",
nowrap,
maxLines,
children,
...rest
}: TextProps) {
const resolvedClassName = cn(
FONT_CONFIG[font],
COLOR_CONFIG[color],
nowrap && "whitespace-nowrap"
nowrap && "whitespace-nowrap",
maxLines === 1 && "truncate",
maxLines && maxLines > 1 && "overflow-hidden"
);
const style =
maxLines && maxLines > 1
? ({
display: "-webkit-box",
WebkitBoxOrient: "vertical",
WebkitLineClamp: maxLines,
} as React.CSSProperties)
: undefined;
return (
<Tag {...rest} className={resolvedClassName}>
<Tag {...rest} className={resolvedClassName} style={style}>
{children && resolveStr(children)}
</Tag>
);

View File

@@ -1,3 +1,5 @@
"use client";
import "@opal/core/animations/styles.css";
import React, { createContext, useContext, useState, useCallback } from "react";
import { cn } from "@opal/utils";

View File

@@ -1,3 +1,5 @@
"use client";
import "@opal/core/disabled/styles.css";
import React, { createContext, useContext } from "react";
import { Slot } from "@radix-ui/react-slot";

View File

@@ -0,0 +1,22 @@
import { cn } from "@opal/utils";
import type { IconProps } from "@opal/types";
const SvgBifrost = ({ size, className, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 37 46"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className={cn(className, "text-[#33C19E] dark:text-white")}
{...props}
>
<title>Bifrost</title>
<path
d="M27.6219 46H0V36.8H27.6219V46ZM36.8268 36.8H27.6219V27.6H36.8268V36.8ZM18.4146 27.6H9.2073V18.4H18.4146V27.6ZM36.8268 18.4H27.6219V9.2H36.8268V18.4ZM27.6219 9.2H0V0H27.6219V9.2Z"
fill="currentColor"
/>
</svg>
);
export default SvgBifrost;

View File

@@ -24,6 +24,7 @@ export { default as SvgAzure } from "@opal/icons/azure";
export { default as SvgBarChart } from "@opal/icons/bar-chart";
export { default as SvgBarChartSmall } from "@opal/icons/bar-chart-small";
export { default as SvgBell } from "@opal/icons/bell";
export { default as SvgBifrost } from "@opal/icons/bifrost";
export { default as SvgBlocks } from "@opal/icons/blocks";
export { default as SvgBookOpen } from "@opal/icons/book-open";
export { default as SvgBookmark } from "@opal/icons/bookmark";

View File

@@ -3,7 +3,11 @@
import { Button } from "@opal/components/buttons/button/components";
import type { ContainerSizeVariants } from "@opal/types";
import SvgEdit from "@opal/icons/edit";
import type { IconFunctionComponent } from "@opal/types";
import type { IconFunctionComponent, RichStr } from "@opal/types";
import {
resolveStr,
toPlainString,
} from "@opal/components/text/InlineMarkdown";
import { cn } from "@opal/utils";
import { useState } from "react";
@@ -35,10 +39,10 @@ interface ContentLgProps {
icon?: IconFunctionComponent;
/** Main title text. */
title: string;
title: string | RichStr;
/** Optional description below the title. */
description?: string;
description?: string | RichStr;
/** Enable inline editing of the title. */
editable?: boolean;
@@ -96,18 +100,18 @@ function ContentLg({
ref,
}: ContentLgProps) {
const [editing, setEditing] = useState(false);
const [editValue, setEditValue] = useState(title);
const [editValue, setEditValue] = useState(toPlainString(title));
const config = CONTENT_LG_PRESETS[sizePreset];
function startEditing() {
setEditValue(title);
setEditValue(toPlainString(title));
setEditing(true);
}
function commit() {
const value = editValue.trim();
if (value && value !== title) onTitleChange?.(value);
if (value && value !== toPlainString(title)) onTitleChange?.(value);
setEditing(false);
}
@@ -157,7 +161,7 @@ function ContentLg({
onKeyDown={(e) => {
if (e.key === "Enter") commit();
if (e.key === "Escape") {
setEditValue(title);
setEditValue(toPlainString(title));
setEditing(false);
}
}}
@@ -174,9 +178,9 @@ function ContentLg({
)}
onClick={editable ? startEditing : undefined}
style={{ height: config.lineHeight }}
title={title}
title={toPlainString(title)}
>
{title}
{resolveStr(title)}
</span>
)}
@@ -199,9 +203,9 @@ function ContentLg({
)}
</div>
{description && (
{description && toPlainString(description) && (
<div className="opal-content-lg-description font-secondary-body text-text-03">
{description}
{resolveStr(description)}
</div>
)}
</div>

View File

@@ -7,7 +7,11 @@ import SvgAlertCircle from "@opal/icons/alert-circle";
import SvgAlertTriangle from "@opal/icons/alert-triangle";
import SvgEdit from "@opal/icons/edit";
import SvgXOctagon from "@opal/icons/x-octagon";
import type { IconFunctionComponent } from "@opal/types";
import type { IconFunctionComponent, RichStr } from "@opal/types";
import {
resolveStr,
toPlainString,
} from "@opal/components/text/InlineMarkdown";
import { cn } from "@opal/utils";
import { useRef, useState } from "react";
@@ -25,7 +29,6 @@ interface ContentMdPresetConfig {
iconColorClass: string;
titleFont: string;
lineHeight: string;
gap: string;
/** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
editButtonSize: ContainerSizeVariants;
editButtonPadding: string;
@@ -41,10 +44,10 @@ interface ContentMdProps {
icon?: IconFunctionComponent;
/** Main title text. */
title: string;
title: string | RichStr;
/** Optional description text below the title. */
description?: string;
description?: string | RichStr;
/** Enable inline editing of the title. */
editable?: boolean;
@@ -70,6 +73,15 @@ interface ContentMdProps {
/** When `true`, the title color hooks into `Interactive`'s `--interactive-foreground` variable. */
withInteractive?: boolean;
/** Optional class name applied to the title element. */
titleClassName?: string;
/** Optional class name applied to the icon element. */
iconClassName?: string;
/** Content rendered below the description, indented to align with it. */
bottomChildren?: React.ReactNode;
/** Ref forwarded to the root `<div>`. */
ref?: React.Ref<HTMLDivElement>;
}
@@ -85,7 +97,6 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
iconColorClass: "text-text-04",
titleFont: "font-main-content-emphasis",
lineHeight: "1.5rem",
gap: "0.125rem",
editButtonSize: "sm",
editButtonPadding: "p-0",
optionalFont: "font-main-content-muted",
@@ -98,7 +109,6 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
iconColorClass: "text-text-03",
titleFont: "font-main-ui-action",
lineHeight: "1.25rem",
gap: "0.25rem",
editButtonSize: "xs",
editButtonPadding: "p-0",
optionalFont: "font-main-ui-muted",
@@ -111,7 +121,6 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
iconColorClass: "text-text-04",
titleFont: "font-secondary-action",
lineHeight: "1rem",
gap: "0.125rem",
editButtonSize: "2xs",
editButtonPadding: "p-0",
optionalFont: "font-secondary-action",
@@ -146,22 +155,25 @@ function ContentMd({
tag,
sizePreset = "main-ui",
withInteractive,
titleClassName,
iconClassName,
bottomChildren,
ref,
}: ContentMdProps) {
const [editing, setEditing] = useState(false);
const [editValue, setEditValue] = useState(title);
const [editValue, setEditValue] = useState(toPlainString(title));
const inputRef = useRef<HTMLInputElement>(null);
const config = CONTENT_MD_PRESETS[sizePreset];
function startEditing() {
setEditValue(title);
setEditValue(toPlainString(title));
setEditing(true);
}
function commit() {
const value = editValue.trim();
if (value && value !== title) onTitleChange?.(value);
if (value && value !== toPlainString(title)) onTitleChange?.(value);
setEditing(false);
}
@@ -170,7 +182,6 @@ function ContentMd({
ref={ref}
className="opal-content-md"
data-interactive={withInteractive || undefined}
style={{ gap: config.gap }}
>
<div
className="opal-content-md-header"
@@ -185,7 +196,11 @@ function ContentMd({
style={{ minHeight: config.lineHeight }}
>
<Icon
className={cn("opal-content-md-icon", config.iconColorClass)}
className={cn(
"opal-content-md-icon",
config.iconColorClass,
iconClassName
)}
style={{ width: config.iconSize, height: config.iconSize }}
/>
</div>
@@ -215,7 +230,7 @@ function ContentMd({
onKeyDown={(e) => {
if (e.key === "Enter") commit();
if (e.key === "Escape") {
setEditValue(title);
setEditValue(toPlainString(title));
setEditing(false);
}
}}
@@ -228,13 +243,14 @@ function ContentMd({
"opal-content-md-title",
config.titleFont,
"text-text-04",
editable && "cursor-pointer"
editable && "cursor-pointer",
titleClassName
)}
title={title}
title={toPlainString(title)}
onClick={editable ? startEditing : undefined}
style={{ height: config.lineHeight }}
>
{title}
{resolveStr(title)}
</span>
)}
@@ -288,12 +304,19 @@ function ContentMd({
</div>
</div>
{description && (
{description && toPlainString(description) && (
<div
className="opal-content-md-description font-secondary-body text-text-03"
style={Icon ? { paddingLeft: config.descriptionIndent } : undefined}
>
{description}
{resolveStr(description)}
</div>
)}
{bottomChildren && (
<div
style={Icon ? { paddingLeft: config.descriptionIndent } : undefined}
>
{bottomChildren}
</div>
)}
</div>

View File

@@ -1,6 +1,10 @@
"use client";
import type { IconFunctionComponent } from "@opal/types";
import type { IconFunctionComponent, RichStr } from "@opal/types";
import {
resolveStr,
toPlainString,
} from "@opal/components/text/InlineMarkdown";
import { cn } from "@opal/utils";
// ---------------------------------------------------------------------------
@@ -30,7 +34,7 @@ interface ContentSmProps {
icon?: IconFunctionComponent;
/** Main title text (read-only — editing is not supported). */
title: string;
title: string | RichStr;
/** Size preset. Default: `"main-ui"`. */
sizePreset?: ContentSmSizePreset;
@@ -118,9 +122,9 @@ function ContentSm({
<span
className={cn("opal-content-sm-title", config.titleFont)}
style={{ height: config.lineHeight }}
title={title}
title={toPlainString(title)}
>
{title}
{resolveStr(title)}
</span>
</div>
);

View File

@@ -3,7 +3,11 @@
import { Button } from "@opal/components/buttons/button/components";
import type { ContainerSizeVariants } from "@opal/types";
import SvgEdit from "@opal/icons/edit";
import type { IconFunctionComponent } from "@opal/types";
import type { IconFunctionComponent, RichStr } from "@opal/types";
import {
resolveStr,
toPlainString,
} from "@opal/components/text/InlineMarkdown";
import { cn } from "@opal/utils";
import { useState } from "react";
@@ -41,10 +45,10 @@ interface ContentXlProps {
icon?: IconFunctionComponent;
/** Main title text. */
title: string;
title: string | RichStr;
/** Optional description below the title. */
description?: string;
description?: string | RichStr;
/** Enable inline editing of the title. */
editable?: boolean;
@@ -116,18 +120,18 @@ function ContentXl({
ref,
}: ContentXlProps) {
const [editing, setEditing] = useState(false);
const [editValue, setEditValue] = useState(title);
const [editValue, setEditValue] = useState(toPlainString(title));
const config = CONTENT_XL_PRESETS[sizePreset];
function startEditing() {
setEditValue(title);
setEditValue(toPlainString(title));
setEditing(true);
}
function commit() {
const value = editValue.trim();
if (value && value !== title) onTitleChange?.(value);
if (value && value !== toPlainString(title)) onTitleChange?.(value);
setEditing(false);
}
@@ -214,7 +218,7 @@ function ContentXl({
onKeyDown={(e) => {
if (e.key === "Enter") commit();
if (e.key === "Escape") {
setEditValue(title);
setEditValue(toPlainString(title));
setEditing(false);
}
}}
@@ -231,9 +235,9 @@ function ContentXl({
)}
onClick={editable ? startEditing : undefined}
style={{ height: config.lineHeight }}
title={title}
title={toPlainString(title)}
>
{title}
{resolveStr(title)}
</span>
)}
@@ -256,9 +260,9 @@ function ContentXl({
)}
</div>
{description && (
{description && toPlainString(description) && (
<div className="opal-content-xl-description font-secondary-body text-text-03">
{description}
{resolveStr(description)}
</div>
)}
</div>

View File

@@ -1,3 +1,39 @@
// ---------------------------------------------------------------------------
// NOTE (@raunakab): Why Content uses resolveStr() instead of <Text>
//
// Content sub-components (ContentXl, ContentLg, ContentMd, ContentSm) render
// titles and descriptions inside styled <span> elements that carry CSS classes
// (e.g., `.opal-content-md-title`) for:
//
// 1. Truncation — `-webkit-box` + `-webkit-line-clamp` for single-line
// clamping with ellipsis. This requires the text to be a DIRECT child
// of the `-webkit-box` element. Wrapping it in a child <span> (which
// is what <Text> renders) breaks the clamping behavior.
//
// 2. Pixel-exact sizing — the wrapper <span> has an explicit `height`
// matching the font's `line-height`. Adding a child <Text> <span>
// inside creates a double-span where the inner element's line-height
// conflicts with the outer element's height, causing a ~4px vertical
// offset.
//
// 3. Interactive color overrides — CSS selectors like
// `.opal-content-md[data-interactive] .opal-content-md-title` set
// `color: var(--interactive-foreground)`. <Text> with `color="inherit"`
// can inherit this, but <Text> with any explicit color prop overrides
// it. And the wrapper <span> needs the CSS class for the selector to
// match — removing it breaks the cascade.
//
// 4. Horizontal padding — the title CSS class applies `padding: 0 0.125rem`
// (2px). Since <Text> uses WithoutStyles (no className/style), this
// padding cannot be applied to <Text> directly. A wrapper <div> was
// attempted but introduced additional layout conflicts.
//
// For these reasons, Content uses `resolveStr()` from InlineMarkdown.tsx to
// handle `string | RichStr` rendering. `resolveStr()` returns a ReactNode
// that slots directly into the existing single <span> — no extra wrapper,
// no layout conflicts, pixel-exact match with main.
// ---------------------------------------------------------------------------
import "@opal/layouts/content/styles.css";
import {
ContentSm,
@@ -17,7 +53,7 @@ import {
type ContentMdProps,
} from "@opal/layouts/content/ContentMd";
import type { TagProps } from "@opal/components/tag/components";
import type { IconFunctionComponent } from "@opal/types";
import type { IconFunctionComponent, RichStr } from "@opal/types";
import { widthVariants } from "@opal/shared";
import type { ExtremaSizeVariants } from "@opal/types";
@@ -39,10 +75,10 @@ interface ContentBaseProps {
icon?: IconFunctionComponent;
/** Main title text. */
title: string;
title: string | RichStr;
/** Optional description below the title. */
description?: string;
description?: string | RichStr;
/** Enable inline editing of the title. */
editable?: boolean;
@@ -67,6 +103,12 @@ interface ContentBaseProps {
/** Ref forwarded to the root `<div>` of the resolved layout. */
ref?: React.Ref<HTMLDivElement>;
/** Optional class name applied to the icon element. */
iconClassName?: string;
/** Content rendered below the description, indented to align with it (MdContent only). */
bottomChildren?: React.ReactNode;
}
// ---------------------------------------------------------------------------
@@ -102,6 +144,8 @@ type MdContentProps = ContentBaseProps & {
auxIcon?: "info-gray" | "info-blue" | "warning" | "error";
/** Tag rendered beside the title. */
tag?: TagProps;
/** Optional class name applied to the title element. */
titleClassName?: string;
};
/** ContentSm does not support descriptions or inline editing. */

View File

@@ -1,4 +1,5 @@
import type { IconFunctionComponent } from "@opal/types";
import type { IconFunctionComponent, RichStr } from "@opal/types";
import { Text } from "@opal/components";
// ---------------------------------------------------------------------------
// Types
@@ -9,10 +10,10 @@ interface IllustrationContentProps {
illustration?: IconFunctionComponent;
/** Main title text, center-aligned. Uses `font-main-content-emphasis`. */
title: string;
title: string | RichStr;
/** Optional description below the title, center-aligned. Uses `font-secondary-body`. */
description?: string;
description?: string | RichStr;
}
// ---------------------------------------------------------------------------
@@ -68,9 +69,13 @@ function IllustrationContent({
/>
)}
<div className="flex flex-col items-center text-center">
<p className="font-main-content-emphasis text-text-04">{title}</p>
<Text font="main-content-emphasis" color="text-04" as="p">
{title}
</Text>
{description && (
<p className="font-secondary-body text-text-03">{description}</p>
<Text font="secondary-body" color="text-03" as="p">
{description}
</Text>
)}
</div>
</div>

View File

@@ -30,8 +30,11 @@ import CreateButton from "@/refresh-components/buttons/CreateButton";
import { Button } from "@opal/components";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import Text from "@/refresh-components/texts/Text";
import { SvgEdit, SvgInfo, SvgKey, SvgRefreshCw } from "@opal/icons";
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
import Message from "@/refresh-components/messages/Message";
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
import { useBillingInformation } from "@/hooks/useBillingInformation";
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
const route = ADMIN_ROUTES.API_KEYS;
@@ -44,6 +47,11 @@ function Main() {
} = useSWR<APIKey[]>("/api/admin/api-key", errorHandlingFetcher);
const canCreateKeys = useCloudSubscription();
const { data: billingData } = useBillingInformation();
const isTrialing =
billingData !== undefined &&
hasActiveSubscription(billingData) &&
billingData.status === BillingStatus.TRIALING;
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
const [keyIsGenerating, setKeyIsGenerating] = useState(false);
@@ -75,6 +83,16 @@ function Main() {
const introSection = (
<div className="flex flex-col items-start gap-4">
{isTrialing && (
<Message
static
warning
close={false}
className="w-full"
text="Upgrade to a paid plan to create API keys."
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
/>
)}
<Text as="p">
API Keys allow you to access Onyx APIs programmatically.
{canCreateKeys
@@ -85,23 +103,9 @@ function Main() {
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
Create API Key
</CreateButton>
) : (
<div className="flex flex-col gap-2 rounded-lg bg-background-tint-02 p-4">
<div className="flex items-center gap-1.5">
<Text as="p" text04>
Upgrade to a paid plan to create API keys.
</Text>
<Button
variant="none"
prominence="tertiary"
size="2xs"
icon={SvgInfo}
tooltip="API keys enable programmatic access to Onyx for service accounts and integrations. Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
/>
</div>
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
</div>
)}
) : isTrialing ? (
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
) : null}
</div>
);

View File

@@ -0,0 +1,387 @@
/**
* Tests for BillingPage handleBillingReturn retry logic.
*
* The retry logic retries claimLicense up to 3 times with 2s backoff
* when returning from a Stripe checkout session. This prevents the user
* from getting stranded when the Stripe webhook fires concurrently with
* the browser redirect and the license isn't ready yet.
*/
import React from "react";
import { render, screen, waitFor } from "@tests/setup/test-utils";
import { act } from "@testing-library/react";
// ---- Stable mock objects (must be named with mock* prefix for jest hoisting) ----
// useRouter and useSearchParams must return the SAME reference each call, otherwise
// React's useEffect sees them as changed and re-runs the effect on every render.
const mockRouter = {
replace: jest.fn() as jest.Mock,
refresh: jest.fn() as jest.Mock,
};
const mockSearchParams = {
get: jest.fn() as jest.Mock,
};
const mockClaimLicense = jest.fn() as jest.Mock;
const mockRefreshBilling = jest.fn() as jest.Mock;
const mockRefreshLicense = jest.fn() as jest.Mock;
// ---- Mocks ----
jest.mock("next/navigation", () => ({
useRouter: () => mockRouter,
useSearchParams: () => mockSearchParams,
}));
jest.mock("@/layouts/settings-layouts", () => ({
Root: ({ children }: { children: React.ReactNode }) => (
<div data-testid="settings-root">{children}</div>
),
Header: () => <div data-testid="settings-header" />,
Body: ({ children }: { children: React.ReactNode }) => (
<div data-testid="settings-body">{children}</div>
),
}));
jest.mock("@/layouts/general-layouts", () => ({
Section: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
}));
jest.mock("@opal/icons", () => ({
SvgArrowUpCircle: () => <svg />,
SvgWallet: () => <svg />,
}));
jest.mock("./PlansView", () => ({
__esModule: true,
default: () => <div data-testid="plans-view" />,
}));
jest.mock("./CheckoutView", () => ({
__esModule: true,
default: () => <div data-testid="checkout-view" />,
}));
jest.mock("./BillingDetailsView", () => ({
__esModule: true,
default: () => <div data-testid="billing-details-view" />,
}));
jest.mock("./LicenseActivationCard", () => ({
__esModule: true,
default: () => <div data-testid="license-activation-card" />,
}));
jest.mock("@/refresh-components/messages/Message", () => ({
__esModule: true,
default: ({
text,
description,
onClose,
}: {
text: string;
description?: string;
onClose?: () => void;
}) => (
<div data-testid="activating-banner">
<span data-testid="activating-banner-text">{text}</span>
{description && (
<span data-testid="activating-banner-description">{description}</span>
)}
{onClose && (
<button data-testid="activating-banner-close" onClick={onClose}>
Close
</button>
)}
</div>
),
}));
jest.mock("@/lib/billing", () => ({
useBillingInformation: jest.fn(),
useLicense: jest.fn(),
hasActiveSubscription: jest.fn().mockReturnValue(false),
claimLicense: (...args: unknown[]) => mockClaimLicense(...args),
}));
jest.mock("@/lib/constants", () => ({
NEXT_PUBLIC_CLOUD_ENABLED: false,
}));
// ---- Import after mocks ----
import BillingPage from "./page";
import { useBillingInformation, useLicense } from "@/lib/billing";
// ---- Test helpers ----
function setupHooks() {
(useBillingInformation as jest.Mock).mockReturnValue({
data: null,
isLoading: false,
error: null,
refresh: mockRefreshBilling,
});
(useLicense as jest.Mock).mockReturnValue({
data: null,
isLoading: false,
refresh: mockRefreshLicense,
});
}
// ---- Tests ----
describe("BillingPage — handleBillingReturn retry logic", () => {
beforeEach(() => {
jest.clearAllMocks();
jest.useFakeTimers();
setupHooks();
// Default: no billing-return params
mockSearchParams.get.mockReturnValue(null);
// Clear any activating state from prior tests
sessionStorage.clear();
});
afterEach(() => {
jest.useRealTimers();
jest.restoreAllMocks();
});
test("calls claimLicense once and refreshes on first-attempt success", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_test_123" : null
);
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
expect(mockClaimLicense).toHaveBeenCalledWith("cs_test_123");
});
expect(mockRouter.refresh).toHaveBeenCalled();
expect(mockRefreshBilling).toHaveBeenCalled();
// URL cleaned up after checkout return
expect(mockRouter.replace).toHaveBeenCalledWith("/admin/billing", {
scroll: false,
});
});
test("retries after first failure and succeeds on second attempt", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_retry_test" : null
);
mockClaimLicense
.mockRejectedValueOnce(new Error("License not ready yet"))
.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(2);
});
// On eventual success, router and billing should be refreshed
expect(mockRouter.refresh).toHaveBeenCalled();
expect(mockRefreshBilling).toHaveBeenCalled();
});
test("retries all 3 times then navigates to details even on total failure", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_all_fail" : null
);
// All 3 attempts fail
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(3);
});
// User stays on plans view with the activating banner
await waitFor(() => {
expect(screen.getByTestId("plans-view")).toBeInTheDocument();
});
// refreshBilling still fires so billing state is up to date
expect(mockRefreshBilling).toHaveBeenCalled();
// Failure is logged
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining("Failed to sync license after billing return"),
expect.any(Error)
);
consoleSpy.mockRestore();
});
test("calls claimLicense without session_id on portal_return", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "portal_return" ? "true" : null
);
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
// No session_id for portal returns — called with undefined
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
});
expect(mockRefreshBilling).toHaveBeenCalled();
});
test("does not call claimLicense when no billing-return params present", async () => {
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(mockClaimLicense).not.toHaveBeenCalled();
});
test("shows activating banner and sets sessionStorage on 3x retry failure", async () => {
mockSearchParams.get.mockImplementation((key: string) =>
key === "session_id" ? "cs_all_fail" : null
);
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
const consoleSpy = jest
.spyOn(console, "error")
.mockImplementation(() => {});
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
await waitFor(() => {
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
});
expect(screen.getByTestId("activating-banner-text")).toHaveTextContent(
"Your license is still activating"
);
expect(
sessionStorage.getItem("billing_license_activating_until")
).not.toBeNull();
consoleSpy.mockRestore();
});
test("banner not rendered when no activating state", async () => {
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
});
test("banner shown on mount when sessionStorage key is set and not expired", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
// Flush React effects — banner is visible from lazy state init, no timer advancement needed
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
});
test("banner not shown on mount when sessionStorage key is expired", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() - 1000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
await act(async () => {
await jest.runAllTimersAsync();
});
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
});
test("poll calls claimLicense after 15s and clears banner on success", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
// Poll attempt succeeds
mockClaimLicense.mockResolvedValueOnce({ success: true });
render(<BillingPage />);
// Flush effects — banner visible from lazy state init
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
// Advance past one poll interval (15s)
await act(async () => {
await jest.advanceTimersByTimeAsync(15_000);
});
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
expect(mockRefreshBilling).toHaveBeenCalled();
expect(mockRefreshLicense).toHaveBeenCalled();
expect(mockRouter.refresh).toHaveBeenCalled();
});
test("close button removes banner and clears sessionStorage", async () => {
sessionStorage.setItem(
"billing_license_activating_until",
String(Date.now() + 120_000)
);
mockSearchParams.get.mockReturnValue(null);
render(<BillingPage />);
// Flush effects — banner visible from lazy state init
await act(async () => {});
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
const closeButton = screen.getByTestId("activating-banner-close");
await act(async () => {
closeButton.click();
});
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
expect(
sessionStorage.getItem("billing_license_activating_until")
).toBeNull();
});
});

View File

@@ -17,6 +17,7 @@ import {
} from "@/lib/billing";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { useUser } from "@/providers/UserProvider";
import Message from "@/refresh-components/messages/Message";
import PlansView from "./PlansView";
import CheckoutView from "./CheckoutView";
@@ -24,6 +25,9 @@ import BillingDetailsView from "./BillingDetailsView";
import LicenseActivationCard from "./LicenseActivationCard";
import "./billing.css";
// sessionStorage key: value is a unix-ms expiry timestamp
const BILLING_ACTIVATING_KEY = "billing_license_activating_until";
// ----------------------------------------------------------------------------
// Types
// ----------------------------------------------------------------------------
@@ -105,6 +109,7 @@ export default function BillingPage() {
const [transitionType, setTransitionType] = useState<
"expand" | "collapse" | "fade"
>("fade");
const [isActivating, setIsActivating] = useState<boolean>(false);
const {
data: billingData,
@@ -155,6 +160,17 @@ export default function BillingPage() {
view,
]);
// Read activating state from sessionStorage after mount (avoids SSR hydration mismatch)
useEffect(() => {
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
if (!raw) return;
if (Number(raw) > Date.now()) {
setIsActivating(true);
} else {
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
}
}, []);
// Show license activation card when there's a Stripe error
useEffect(() => {
if (hasStripeError && !showLicenseActivationInput) {
@@ -172,24 +188,96 @@ export default function BillingPage() {
router.replace("/admin/billing", { scroll: false });
let cancelled = false;
const handleBillingReturn = async () => {
if (!NEXT_PUBLIC_CLOUD_ENABLED) {
try {
// After checkout, exchange session_id for license; after portal, re-sync license
await claimLicense(sessionId ?? undefined);
refreshLicense();
// Refresh the page to update settings (including ee_features_enabled)
router.refresh();
// Navigate to billing details now that the license is active
changeView("details");
} catch (error) {
console.error("Failed to sync license after billing return:", error);
// Retry up to 3 times with 2s backoff. The license may not be available
// immediately if the Stripe webhook hasn't finished processing yet
// (redirect and webhook fire nearly simultaneously).
let lastError: Error | null = null;
for (let attempt = 0; attempt < 3; attempt++) {
if (cancelled) return;
try {
// After checkout, exchange session_id for license; after portal, re-sync license
await claimLicense(sessionId ?? undefined);
if (cancelled) return;
refreshLicense();
// Refresh the page to update settings (including ee_features_enabled)
router.refresh();
// Navigate to billing details now that the license is active
changeView("details");
lastError = null;
break;
} catch (err) {
lastError = err instanceof Error ? err : new Error("Unknown error");
if (attempt < 2) {
await new Promise((resolve) => setTimeout(resolve, 2000));
}
}
}
if (cancelled) return;
if (lastError) {
console.error(
"Failed to sync license after billing return:",
lastError
);
// Show an activating banner on the plans view and keep retrying in the background.
sessionStorage.setItem(
BILLING_ACTIVATING_KEY,
String(Date.now() + 120_000)
);
setIsActivating(true);
changeView("plans");
}
}
refreshBilling();
if (!cancelled) refreshBilling();
};
handleBillingReturn();
}, [searchParams, router, refreshBilling, refreshLicense]);
return () => {
cancelled = true;
};
// changeView intentionally omitted: it only calls stable state setters and the
// effect runs at most once (when session_id/portal_return params are present).
}, [searchParams, router, refreshBilling, refreshLicense]); // eslint-disable-line react-hooks/exhaustive-deps
// Poll every 15s while activating, up to 2 minutes, to detect when the license arrives.
useEffect(() => {
if (!isActivating) return;
let requestInFlight = false;
const intervalId = setInterval(async () => {
if (requestInFlight) return;
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
if (!raw || Number(raw) <= Date.now()) {
// Expired — stop immediately without waiting for React cleanup
clearInterval(intervalId);
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
return;
}
requestInFlight = true;
try {
await claimLicense(undefined);
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
refreshLicense();
refreshBilling();
router.refresh();
changeView("details");
} catch (err) {
// License not ready yet — keep polling. Log so unexpected failures
// (network errors, 500s) are distinguishable from expected 404s.
console.debug("License activation poll: will retry", err);
} finally {
requestInFlight = false;
}
}, 15_000);
return () => clearInterval(intervalId);
}, [isActivating]); // eslint-disable-line react-hooks/exhaustive-deps
const handleRefresh = async () => {
await Promise.all([
@@ -386,6 +474,22 @@ export default function BillingPage() {
/>
<SettingsLayouts.Body>
<div className="flex flex-col items-center gap-6">
{isActivating && (
<Message
static
warning
large
text="Your license is still activating"
description="Your license is being processed. You'll be taken to billing details automatically once confirmed."
icon
close
onClose={() => {
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
setIsActivating(false);
}}
className="w-full"
/>
)}
{renderContent()}
{renderFooter()}
</div>

View File

@@ -1,11 +1,12 @@
"use client";
import { useState, useMemo } from "react";
import { useState, useMemo, useEffect } from "react";
import useSWR from "swr";
import Text from "@/refresh-components/texts/Text";
import { Select } from "@/refresh-components/cards";
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
import {
@@ -17,9 +18,16 @@ import {
ImageGenerationConfigView,
setDefaultImageGenerationConfig,
unsetDefaultImageGenerationConfig,
deleteImageGenerationConfig,
} from "@/lib/configuration/imageConfigurationService";
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
import Message from "@/refresh-components/messages/Message";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { Button } from "@opal/components";
import { SvgSlash, SvgUnplug } from "@opal/icons";
const NO_DEFAULT_VALUE = "__none__";
export default function ImageGenerationContent() {
const {
@@ -47,6 +55,11 @@ export default function ImageGenerationContent() {
);
const [editConfig, setEditConfig] =
useState<ImageGenerationConfigView | null>(null);
const [disconnectProvider, setDisconnectProvider] =
useState<ImageProvider | null>(null);
const [replacementProviderId, setReplacementProviderId] = useState<
string | null
>(null);
const connectedProviderIds = useMemo(() => {
return new Set(configs.map((c) => c.image_provider_id));
@@ -115,6 +128,29 @@ export default function ImageGenerationContent() {
modal.toggle(true);
};
const handleDisconnect = async () => {
if (!disconnectProvider) return;
try {
// If a replacement was selected (not "No Default"), activate it first
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
await setDefaultImageGenerationConfig(replacementProviderId);
}
await deleteImageGenerationConfig(disconnectProvider.image_provider_id);
toast.success(`${disconnectProvider.title} disconnected`);
refetchConfigs();
refetchProviders();
} catch (error) {
console.error("Failed to disconnect image generation provider:", error);
toast.error(
error instanceof Error ? error.message : "Failed to disconnect"
);
} finally {
setDisconnectProvider(null);
setReplacementProviderId(null);
}
};
const handleModalSuccess = () => {
toast.success("Provider configured successfully");
setEditConfig(null);
@@ -130,6 +166,36 @@ export default function ImageGenerationContent() {
);
}
// Compute replacement options when disconnecting an active provider
const isDisconnectingDefault =
disconnectProvider &&
defaultConfig?.image_provider_id === disconnectProvider.image_provider_id;
// Group connected replacement models by provider (excluding the model being disconnected)
const replacementGroups = useMemo(() => {
if (!disconnectProvider) return [];
return IMAGE_PROVIDER_GROUPS.map((group) => ({
...group,
providers: group.providers.filter(
(p) =>
p.image_provider_id !== disconnectProvider.image_provider_id &&
connectedProviderIds.has(p.image_provider_id)
),
})).filter((g) => g.providers.length > 0);
}, [disconnectProvider, connectedProviderIds]);
const needsReplacement = !!isDisconnectingDefault;
const hasReplacements = replacementGroups.length > 0;
// Auto-select first replacement when modal opens
useEffect(() => {
if (needsReplacement && !replacementProviderId && hasReplacements) {
const firstGroup = replacementGroups[0];
const firstModel = firstGroup?.providers[0];
if (firstModel) setReplacementProviderId(firstModel.image_provider_id);
}
}, [disconnectProvider]); // eslint-disable-line react-hooks/exhaustive-deps
return (
<>
<div className="flex flex-col gap-6">
@@ -175,6 +241,11 @@ export default function ImageGenerationContent() {
onSelect={() => handleSelect(provider)}
onDeselect={() => handleDeselect(provider)}
onEdit={() => handleEdit(provider)}
onDisconnect={
getStatus(provider) !== "disconnected"
? () => setDisconnectProvider(provider)
: undefined
}
/>
))}
</div>
@@ -182,6 +253,105 @@ export default function ImageGenerationContent() {
))}
</div>
{disconnectProvider && (
<ConfirmationModalLayout
icon={SvgUnplug}
title={`Disconnect ${disconnectProvider.title}`}
description="This will remove the stored credentials for this provider."
onClose={() => {
setDisconnectProvider(null);
setReplacementProviderId(null);
}}
submit={
<Button
variant="danger"
onClick={() => void handleDisconnect()}
disabled={
needsReplacement && hasReplacements && !replacementProviderId
}
>
Disconnect
</Button>
}
>
{needsReplacement ? (
hasReplacements ? (
<Section alignItems="start">
<Text as="p" text03>
<b>{disconnectProvider.title}</b> is currently the default
image generation model. Session history will be preserved.
</Text>
<Section alignItems="start" gap={0.25}>
<Text as="p" text04>
Set New Default
</Text>
<InputSelect
value={replacementProviderId ?? undefined}
onValueChange={(v) => setReplacementProviderId(v)}
>
<InputSelect.Trigger placeholder="Select a replacement model" />
<InputSelect.Content>
{replacementGroups.map((group) => (
<InputSelect.Group key={group.name}>
<InputSelect.Label>{group.name}</InputSelect.Label>
{group.providers.map((p) => (
<InputSelect.Item
key={p.image_provider_id}
value={p.image_provider_id}
icon={() => (
<ProviderIcon
provider={p.provider_name}
size={16}
/>
)}
>
{p.title}
</InputSelect.Item>
))}
</InputSelect.Group>
))}
<InputSelect.Separator />
<InputSelect.Item
value={NO_DEFAULT_VALUE}
icon={SvgSlash}
>
<span>
<b>No Default</b>
<span className="text-text-03">
{" "}
(Disable Image Generation)
</span>
</span>
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</Section>
</Section>
) : (
<>
<Text as="p" text03>
<b>{disconnectProvider.title}</b> is currently the default
image generation model.
</Text>
<Text as="p" text03>
Connect another provider to continue using image generation.
</Text>
</>
)
) : (
<>
<Text as="p" text03>
<b>{disconnectProvider.title}</b> models will no longer be used
to generate images.
</Text>
<Text as="p" text03>
Session history will be preserved.
</Text>
</>
)}
</ConfirmationModalLayout>
)}
{activeProvider && (
<modal.Provider>
<ImageGenerationConnectionModal

View File

@@ -23,6 +23,7 @@ import {
BedrockModelResponse,
LMStudioModelResponse,
LiteLLMProxyModelResponse,
BifrostModelResponse,
ModelConfiguration,
LLMProviderName,
BedrockFetchParams,
@@ -30,8 +31,9 @@ import {
LMStudioFetchParams,
OpenRouterFetchParams,
LiteLLMProxyFetchParams,
BifrostFetchParams,
} from "@/interfaces/llm";
import { SvgAws, SvgOpenrouter } from "@opal/icons";
import { SvgAws, SvgBifrost, SvgOpenrouter } from "@opal/icons";
// Aggregator providers that host models from multiple vendors
export const AGGREGATOR_PROVIDERS = new Set([
@@ -41,6 +43,7 @@ export const AGGREGATOR_PROVIDERS = new Set([
"ollama_chat",
"lm_studio",
"litellm_proxy",
"bifrost",
"vertex_ai",
]);
@@ -78,6 +81,7 @@ export const getProviderIcon = (
bedrock_converse: SvgAws,
openrouter: SvgOpenrouter,
litellm_proxy: LiteLLMIcon,
bifrost: SvgBifrost,
vertex_ai: GeminiIcon,
};
@@ -263,8 +267,11 @@ export const fetchOpenRouterModels = async (
try {
const errorData = await response.json();
errorMessage = errorData.detail || errorData.message || errorMessage;
} catch {
// ignore JSON parsing errors
} catch (jsonError) {
console.warn(
"Failed to parse OpenRouter model fetch error response",
jsonError
);
}
return { models: [], error: errorMessage };
}
@@ -319,8 +326,11 @@ export const fetchLMStudioModels = async (
try {
const errorData = await response.json();
errorMessage = errorData.detail || errorData.message || errorMessage;
} catch {
// ignore JSON parsing errors
} catch (jsonError) {
console.warn(
"Failed to parse LM Studio model fetch error response",
jsonError
);
}
return { models: [], error: errorMessage };
}
@@ -343,6 +353,64 @@ export const fetchLMStudioModels = async (
}
};
/**
* Fetches Bifrost models directly without any form state dependencies.
* Uses snake_case params to match API structure.
*/
export const fetchBifrostModels = async (
params: BifrostFetchParams
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
const apiBase = params.api_base;
if (!apiBase) {
return { models: [], error: "API Base is required" };
}
try {
const response = await fetch("/api/admin/llm/bifrost/available-models", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
api_base: apiBase,
api_key: params.api_key,
provider_name: params.provider_name,
}),
signal: params.signal,
});
if (!response.ok) {
let errorMessage = "Failed to fetch models";
try {
const errorData = await response.json();
errorMessage = errorData.detail || errorData.message || errorMessage;
} catch (jsonError) {
console.warn(
"Failed to parse Bifrost model fetch error response",
jsonError
);
}
return { models: [], error: errorMessage };
}
const data: BifrostModelResponse[] = await response.json();
const models: ModelConfiguration[] = data.map((modelData) => ({
name: modelData.name,
display_name: modelData.display_name,
is_visible: true,
max_input_tokens: modelData.max_input_tokens,
supports_image_input: modelData.supports_image_input,
supports_reasoning: modelData.supports_reasoning,
}));
return { models };
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : "Unknown error";
return { models: [], error: errorMessage };
}
};
/**
* Fetches LiteLLM Proxy models directly without any form state dependencies.
* Uses snake_case params to match API structure.
@@ -456,6 +524,13 @@ export const fetchModels = async (
provider_name: formValues.name,
signal,
});
case LLMProviderName.BIFROST:
return fetchBifrostModels({
api_base: formValues.api_base,
api_key: formValues.api_key,
provider_name: formValues.name,
signal,
});
default:
return { models: [], error: `Unknown provider: ${providerName}` };
}
@@ -469,6 +544,7 @@ export function canProviderFetchModels(providerName?: string) {
case LLMProviderName.LM_STUDIO:
case LLMProviderName.OPENROUTER:
case LLMProviderName.LITELLM_PROXY:
case LLMProviderName.BIFROST:
return true;
default:
return false;

View File

@@ -7,7 +7,9 @@ import {
FailedConnectorIndexingStatus,
ValidStatuses,
} from "@/lib/types";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import { markdown } from "@opal/utils";
import Spacer from "@/refresh-components/Spacer";
import Title from "@/components/ui/title";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
@@ -199,45 +201,30 @@ export default function UpgradingPage({
/>
)}
<Text className="my-4">
{futureEmbeddingModel.switchover_type === "active_only" ? (
<>
The table below shows the re-indexing progress of active
(non-paused) connectors. Once all active connectors have
been re-indexed successfully, the new model will be used
for all search queries. Paused connectors will continue
to be indexed in the background but won&apos;t block the
switchover. Until then, we will use the old model so
that no downtime is necessary during this transition.
<br />
Note: User file re-indexing progress is not shown. You
will see this page until all active connectors are
re-indexed!
</>
) : (
<>
The table below shows the re-indexing progress of all
existing connectors. Once all connectors have been
re-indexed successfully, the new model will be used for
all search queries. Until then, we will use the old
model so that no downtime is necessary during this
transition.
<br />
Note: User file re-indexing progress is not shown. You
will see this page until all user files are re-indexed!
</>
)}
<Spacer rem={1} />
<Text as="p">
{futureEmbeddingModel.switchover_type === "active_only"
? markdown(
"The table below shows the re-indexing progress of active (non-paused) connectors. Once all active connectors have been re-indexed successfully, the new model will be used for all search queries. Paused connectors will continue to be indexed in the background but won't block the switchover. Until then, we will use the old model so that no downtime is necessary during this transition.\nNote: User file re-indexing progress is not shown. You will see this page until all active connectors are re-indexed!"
)
: markdown(
"The table below shows the re-indexing progress of all existing connectors. Once all connectors have been re-indexed successfully, the new model will be used for all search queries. Until then, we will use the old model so that no downtime is necessary during this transition.\nNote: User file re-indexing progress is not shown. You will see this page until all user files are re-indexed!"
)}
</Text>
<Spacer rem={1} />
{sortedReindexingProgress ? (
<>
{futureEmbeddingModel.switchover_type === "active_only" &&
!hasVisibleReindexingProgress && (
<Text className="text-text-700 mt-4">
All connectors are currently paused, so none are
blocking the switchover. Paused connectors will keep
re-indexing in the background.
</Text>
<>
<Spacer rem={1} />
<Text as="p">
All connectors are currently paused, so none are
blocking the switchover. Paused connectors will
keep re-indexing in the background.
</Text>
</>
)}
{hasVisibleReindexingProgress && (
<ReindexingProgressTable

View File

@@ -3,7 +3,7 @@
import { ThreeDotsLoader } from "@/components/Loading";
import { errorHandlingFetcher } from "@/lib/fetcher";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import Title from "@/components/ui/title";
import { Button } from "@opal/components";
import useSWR from "swr";
@@ -107,8 +107,10 @@ function Main() {
<div className="px-1 w-full rounded-lg">
<div className="space-y-4">
<div>
<Text className="font-semibold">Multipass Indexing</Text>
<Text className="text-text-700">
<Text as="p" font="main-ui-action">
Multipass Indexing
</Text>
<Text as="p">
{searchSettings.multipass_indexing
? "Enabled"
: "Disabled"}
@@ -116,8 +118,10 @@ function Main() {
</div>
<div>
<Text className="font-semibold">Contextual RAG</Text>
<Text className="text-text-700">
<Text as="p" font="main-ui-action">
Contextual RAG
</Text>
<Text as="p">
{searchSettings.enable_contextual_rag
? "Enabled"
: "Disabled"}

View File

@@ -1,5 +1,6 @@
"use client";
import { markdown } from "@opal/utils";
import Image from "next/image";
import { FunctionComponent, useState, useEffect } from "react";
import {
@@ -436,22 +437,9 @@ export default function VoiceProviderSetupModal({
{providerType === "azure" && (
<Vertical
title="Target URI"
subDescription={
<>
Paste the endpoint shown in{" "}
<a
href="https://portal.azure.com/"
target="_blank"
rel="noopener noreferrer"
className="underline"
>
Azure Portal (Keys and Endpoint)
</a>
. Onyx extracts the speech region from this URL. Examples:
https://westus.api.cognitive.microsoft.com/ or
https://westus.tts.speech.microsoft.com/.
</>
}
subDescription={markdown(
"Paste the endpoint shown in [Azure Portal (Keys and Endpoint)](https://portal.azure.com/). Onyx extracts the speech region from this URL. Examples: https://westus.api.cognitive.microsoft.com/ or https://westus.tts.speech.microsoft.com/."
)}
nonInteractive
>
<InputTypeIn
@@ -503,24 +491,14 @@ export default function VoiceProviderSetupModal({
{mode === "tts" && (
<Vertical
title="Voice"
subDescription={
<>
This voice will be used for spoken responses. See full list
of supported languages and voices at{" "}
<a
href={
PROVIDER_VOICE_DOCS_URLS[providerType]?.url ??
PROVIDER_DOCS_URLS[providerType]
}
target="_blank"
rel="noopener noreferrer"
className="underline"
>
{PROVIDER_VOICE_DOCS_URLS[providerType]?.label ?? label}
</a>
.
</>
}
subDescription={markdown(
`This voice will be used for spoken responses. See full list of supported languages and voices at [${
PROVIDER_VOICE_DOCS_URLS[providerType]?.label ?? label
}](${
PROVIDER_VOICE_DOCS_URLS[providerType]?.url ??
PROVIDER_DOCS_URLS[providerType]
}).`
)}
nonInteractive
>
<InputComboBox

View File

@@ -1,32 +1,25 @@
"use client";
import Image from "next/image";
import { useMemo, useState, useReducer } from "react";
import { useEffect, useMemo, useState, useReducer } from "react";
import { InfoIcon } from "@/components/icons/icons";
import Text from "@/refresh-components/texts/Text";
import { Select } from "@/refresh-components/cards";
import { Section } from "@/layouts/general-layouts";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { Content } from "@opal/layouts";
import useSWR from "swr";
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
import { ThreeDotsLoader } from "@/components/Loading";
import { Callout } from "@/components/ui/callout";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
import { Disabled } from "@opal/core";
import { cn } from "@/lib/utils";
import {
SvgArrowExchange,
SvgArrowRightCircle,
SvgCheckSquare,
SvgEdit,
SvgGlobe,
SvgOnyxLogo,
SvgX,
} from "@opal/icons";
import { toast } from "@/hooks/useToast";
import { SvgGlobe, SvgOnyxLogo, SvgSlash, SvgUnplug } from "@opal/icons";
import { Button as OpalButton } from "@opal/components";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
const route = ADMIN_ROUTES.WEB_SEARCH;
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import {
SEARCH_PROVIDERS_URL,
SEARCH_PROVIDER_DETAILS,
@@ -58,6 +51,10 @@ import {
} from "@/app/admin/configuration/web-search/WebProviderModalReducer";
import { connectProviderFlow } from "@/app/admin/configuration/web-search/connectProviderFlow";
const NO_DEFAULT_VALUE = "__none__";
const route = ADMIN_ROUTES.WEB_SEARCH;
interface WebSearchProviderView {
id: number;
name: string;
@@ -76,27 +73,151 @@ interface WebContentProviderView {
has_api_key: boolean;
}
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
isHovered: boolean;
onMouseEnter: () => void;
onMouseLeave: () => void;
children: React.ReactNode;
interface DisconnectTargetState {
id: number;
label: string;
category: "search" | "content";
providerType: string;
}
function HoverIconButton({
isHovered,
onMouseEnter,
onMouseLeave,
children,
...buttonProps
}: HoverIconButtonProps) {
function WebSearchDisconnectModal({
disconnectTarget,
searchProviders,
contentProviders,
replacementProviderId,
onReplacementChange,
onClose,
onDisconnect,
}: {
disconnectTarget: DisconnectTargetState;
searchProviders: WebSearchProviderView[];
contentProviders: WebContentProviderView[];
replacementProviderId: string | null;
onReplacementChange: (id: string | null) => void;
onClose: () => void;
onDisconnect: () => void;
}) {
const isSearch = disconnectTarget.category === "search";
// Determine if the target is currently the active/selected provider
const isActive = isSearch
? searchProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
false
: contentProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
false;
// Find other configured providers as replacements
const replacementOptions = isSearch
? searchProviders.filter(
(p) => p.id !== disconnectTarget.id && p.id > 0 && p.has_api_key
)
: contentProviders.filter(
(p) =>
p.id !== disconnectTarget.id &&
p.provider_type !== "onyx_web_crawler" &&
p.id > 0 &&
p.has_api_key
);
const needsReplacement = isActive;
const hasReplacements = replacementOptions.length > 0;
const getLabel = (p: { name: string; provider_type: string }) => {
if (isSearch) {
const details =
SEARCH_PROVIDER_DETAILS[p.provider_type as WebSearchProviderType];
return details?.label ?? p.name ?? p.provider_type;
}
const details = CONTENT_PROVIDER_DETAILS[p.provider_type];
return details?.label ?? p.name ?? p.provider_type;
};
const categoryLabel = isSearch ? "search engine" : "web crawler";
const featureLabel = isSearch ? "web search" : "web crawling";
const disableLabel = isSearch ? "Disable Web Search" : "Disable Web Crawling";
// Auto-select first replacement when modal opens
useEffect(() => {
if (needsReplacement && hasReplacements && !replacementProviderId) {
const first = replacementOptions[0];
if (first) onReplacementChange(String(first.id));
}
}, []); // eslint-disable-line react-hooks/exhaustive-deps
return (
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
{/* TODO(@raunakab): migrate to opal Button once HoverIconButtonProps typing is resolved */}
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
{children}
</Button>
</div>
<ConfirmationModalLayout
icon={SvgUnplug}
title={`Disconnect ${disconnectTarget.label}`}
description="This will remove the stored credentials for this provider."
onClose={onClose}
submit={
<OpalButton
variant="danger"
onClick={onDisconnect}
disabled={
needsReplacement && hasReplacements && !replacementProviderId
}
>
Disconnect
</OpalButton>
}
>
{needsReplacement ? (
hasReplacements ? (
<Section alignItems="start">
<Text as="p" text03>
<b>{disconnectTarget.label}</b> is currently the active{" "}
{categoryLabel}. Search history will be preserved.
</Text>
<Section alignItems="start" gap={0.25}>
<Text as="p" secondaryBody text03>
Set New Default
</Text>
<InputSelect
value={replacementProviderId ?? undefined}
onValueChange={(v) => onReplacementChange(v)}
>
<InputSelect.Trigger placeholder="Select a replacement provider" />
<InputSelect.Content>
{replacementOptions.map((p) => (
<InputSelect.Item key={p.id} value={String(p.id)}>
{getLabel(p)}
</InputSelect.Item>
))}
<InputSelect.Separator />
<InputSelect.Item value={NO_DEFAULT_VALUE} icon={SvgSlash}>
<span>
<b>No Default</b>
<span className="text-text-03"> ({disableLabel})</span>
</span>
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</Section>
</Section>
) : (
<>
<Text as="p" text03>
<b>{disconnectTarget.label}</b> is currently the active{" "}
{categoryLabel}.
</Text>
<Text as="p" text03>
Connect another provider to continue using {featureLabel}.
</Text>
</>
)
) : (
<>
<Text as="p" text03>
{isSearch ? "Web search" : "Web crawling"} will no longer be routed
through <b>{disconnectTarget.label}</b>.
</Text>
<Text as="p" text03>
Search history will be preserved.
</Text>
</>
)}
</ConfirmationModalLayout>
);
}
@@ -105,6 +226,11 @@ export default function Page() {
WebProviderModalReducer,
initialWebProviderModalState
);
const [disconnectTarget, setDisconnectTarget] =
useState<DisconnectTargetState | null>(null);
const [replacementProviderId, setReplacementProviderId] = useState<
string | null
>(null);
const [contentModal, dispatchContentModal] = useReducer(
WebProviderModalReducer,
initialWebProviderModalState
@@ -113,8 +239,6 @@ export default function Page() {
const [contentActivationError, setContentActivationError] = useState<
string | null
>(null);
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
const {
data: searchProvidersData,
error: searchProvidersError,
@@ -833,6 +957,67 @@ export default function Page() {
});
};
const handleDisconnectProvider = async () => {
if (!disconnectTarget) return;
const { id, category } = disconnectTarget;
try {
// If a replacement was selected (not "No Default"), activate it first
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
const repId = Number(replacementProviderId);
const activateEndpoint =
category === "search"
? `/api/admin/web-search/search-providers/${repId}/activate`
: `/api/admin/web-search/content-providers/${repId}/activate`;
const activateResp = await fetch(activateEndpoint, {
method: "POST",
headers: { "Content-Type": "application/json" },
});
if (!activateResp.ok) {
const errorBody = await activateResp.json().catch(() => ({}));
throw new Error(
typeof errorBody?.detail === "string"
? errorBody.detail
: "Failed to activate replacement provider."
);
}
}
const response = await fetch(
`/api/admin/web-search/${category}-providers/${id}`,
{ method: "DELETE" }
);
if (!response.ok) {
const errorBody = await response.json().catch((parseErr) => {
console.error("Failed to parse disconnect error response:", parseErr);
return {};
});
throw new Error(
typeof errorBody?.detail === "string"
? errorBody.detail
: "Failed to disconnect provider."
);
}
toast.success(`${disconnectTarget.label} disconnected`);
await mutateSearchProviders();
await mutateContentProviders();
} catch (error) {
console.error("Failed to disconnect web search provider:", error);
const message =
error instanceof Error ? error.message : "Unexpected error occurred.";
if (category === "search") {
setActivationError(message);
} else {
setContentActivationError(message);
}
} finally {
setDisconnectTarget(null);
setReplacementProviderId(null);
}
};
return (
<>
<SettingsLayouts.Root>
@@ -894,149 +1079,79 @@ export default function Page() {
provider
);
const isActive = provider?.is_active ?? false;
const isHighlighted = isActive;
const providerId = provider?.id;
const canOpenModal =
isBuiltInSearchProviderType(providerType);
const buttonState = (() => {
if (!provider || !isConfigured) {
return {
label: "Connect",
disabled: false,
icon: "arrow" as const,
onClick: canOpenModal
const status: "disconnected" | "connected" | "selected" =
!isConfigured
? "disconnected"
: isActive
? "selected"
: "connected";
return (
<Select
key={`${key}-${providerType}`}
icon={() =>
logoSrc ? (
<Image
src={logoSrc}
alt={`${label} logo`}
width={16}
height={16}
/>
) : (
<SvgGlobe size={16} />
)
}
title={label}
description={subtitle}
status={status}
onConnect={
canOpenModal
? () => {
openSearchModal(providerType, provider);
setActivationError(null);
}
: undefined,
};
}
if (isActive) {
return {
label: "Current Default",
disabled: false,
icon: "check" as const,
onClick: providerId
: undefined
}
onSelect={
providerId
? () => {
void handleActivateSearchProvider(providerId);
}
: undefined
}
onDeselect={
providerId
? () => {
void handleDeactivateSearchProvider(providerId);
}
: undefined,
};
}
return {
label: "Set as Default",
disabled: false,
icon: "arrow-circle" as const,
onClick: providerId
? () => {
void handleActivateSearchProvider(providerId);
}
: undefined,
};
})();
const buttonKey = `search-${key}-${providerType}`;
const isButtonHovered = hoveredButtonKey === buttonKey;
const isCardClickable =
buttonState.icon === "arrow" &&
typeof buttonState.onClick === "function" &&
!buttonState.disabled;
const handleCardClick = () => {
if (isCardClickable) {
buttonState.onClick?.();
}
};
return (
<div
key={`${key}-${providerType}`}
onClick={isCardClickable ? handleCardClick : undefined}
className={cn(
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
isHighlighted
? "border-action-link-05"
: "border-border-01",
isCardClickable &&
"cursor-pointer hover:bg-background-tint-01 transition-colors"
)}
>
<div className="flex flex-1 items-start gap-1 px-2 py-1">
{renderLogo({
logoSrc,
alt: `${label} logo`,
size: 16,
isHighlighted,
})}
<Content
title={label}
description={subtitle}
sizePreset="main-ui"
variant="section"
/>
</div>
<div className="flex items-center justify-end gap-2">
{isConfigured && (
<OpalButton
icon={SvgEdit}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={() => {
if (!canOpenModal) return;
: undefined
}
onEdit={
isConfigured && canOpenModal
? () => {
openSearchModal(
providerType as WebSearchProviderType,
provider
);
}}
aria-label={`Edit ${label}`}
/>
)}
{buttonState.icon === "check" ? (
<HoverIconButton
isHovered={isButtonHovered}
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
onMouseLeave={() => setHoveredButtonKey(null)}
action={true}
tertiary
disabled={buttonState.disabled}
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
>
{buttonState.label}
</HoverIconButton>
) : (
<Disabled
disabled={
buttonState.disabled || !buttonState.onClick
}
>
<OpalButton
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
rightIcon={
buttonState.icon === "arrow"
? SvgArrowExchange
: buttonState.icon === "arrow-circle"
? SvgArrowRightCircle
: undefined
}
>
{buttonState.label}
</OpalButton>
</Disabled>
)}
</div>
</div>
: undefined
}
onDisconnect={
isConfigured && provider && provider.id > 0
? () =>
setDisconnectTarget({
id: provider.id,
label,
category: "search",
providerType,
})
: undefined
}
/>
);
}
)}
@@ -1076,161 +1191,81 @@ export default function Page() {
const isCurrentCrawler =
provider.provider_type === currentContentProviderType;
const buttonState = (() => {
if (!isConfigured) {
return {
label: "Connect",
icon: "arrow" as const,
disabled: false,
onClick: () => {
openContentModal(provider.provider_type, provider);
setContentActivationError(null);
},
};
}
const status: "disconnected" | "connected" | "selected" =
!isConfigured
? "disconnected"
: isCurrentCrawler
? "selected"
: "connected";
if (isCurrentCrawler) {
return {
label: "Current Crawler",
icon: "check" as const,
disabled: false,
onClick: () => {
void handleDeactivateContentProvider(
providerId,
provider.provider_type
);
},
};
}
const canActivate =
providerId > 0 ||
provider.provider_type === "onyx_web_crawler" ||
isConfigured;
const canActivate =
providerId > 0 ||
provider.provider_type === "onyx_web_crawler" ||
isConfigured;
return {
label: "Set as Default",
icon: "arrow-circle" as const,
disabled: !canActivate,
onClick: canActivate
? () => {
void handleActivateContentProvider(provider);
}
: undefined,
};
})();
const contentButtonKey = `content-${provider.provider_type}-${provider.id}`;
const isContentButtonHovered =
hoveredButtonKey === contentButtonKey;
const isContentCardClickable =
buttonState.icon === "arrow" &&
typeof buttonState.onClick === "function" &&
!buttonState.disabled;
const handleContentCardClick = () => {
if (isContentCardClickable) {
buttonState.onClick?.();
}
};
const contentLogoSrc =
CONTENT_PROVIDER_DETAILS[provider.provider_type]?.logoSrc;
return (
<div
<Select
key={`${provider.provider_type}-${provider.id}`}
onClick={
isContentCardClickable
? handleContentCardClick
icon={() =>
contentLogoSrc ? (
<Image
src={contentLogoSrc}
alt={`${label} logo`}
width={16}
height={16}
/>
) : provider.provider_type === "onyx_web_crawler" ? (
<SvgOnyxLogo size={16} />
) : (
<SvgGlobe size={16} />
)
}
title={label}
description={subtitle}
status={status}
selectedLabel="Current Crawler"
onConnect={() => {
openContentModal(provider.provider_type, provider);
setContentActivationError(null);
}}
onSelect={
canActivate
? () => {
void handleActivateContentProvider(provider);
}
: undefined
}
className={cn(
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
isCurrentCrawler
? "border-action-link-05"
: "border-border-01",
isContentCardClickable &&
"cursor-pointer hover:bg-background-tint-01 transition-colors"
)}
>
<div className="flex flex-1 items-start gap-1 px-2 py-1">
{renderLogo({
logoSrc:
CONTENT_PROVIDER_DETAILS[provider.provider_type]
?.logoSrc,
alt: `${label} logo`,
fallback:
provider.provider_type === "onyx_web_crawler" ? (
<SvgOnyxLogo size={16} />
) : undefined,
size: 16,
isHighlighted: isCurrentCrawler,
})}
<Content
title={label}
description={subtitle}
sizePreset="main-ui"
variant="section"
/>
</div>
<div className="flex items-center justify-end gap-2">
{provider.provider_type !== "onyx_web_crawler" &&
isConfigured && (
<OpalButton
icon={SvgEdit}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={() => {
openContentModal(
provider.provider_type,
provider
);
}}
aria-label={`Edit ${label}`}
/>
)}
{buttonState.icon === "check" ? (
<HoverIconButton
isHovered={isContentButtonHovered}
onMouseEnter={() =>
setHoveredButtonKey(contentButtonKey)
onDeselect={() => {
void handleDeactivateContentProvider(
providerId,
provider.provider_type
);
}}
onEdit={
provider.provider_type !== "onyx_web_crawler" &&
isConfigured
? () => {
openContentModal(provider.provider_type, provider);
}
onMouseLeave={() => setHoveredButtonKey(null)}
action={true}
tertiary
disabled={buttonState.disabled}
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
>
{buttonState.label}
</HoverIconButton>
) : (
<Disabled
disabled={
buttonState.disabled || !buttonState.onClick
}
>
<OpalButton
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
rightIcon={
buttonState.icon === "arrow"
? SvgArrowExchange
: buttonState.icon === "arrow-circle"
? SvgArrowRightCircle
: undefined
}
>
{buttonState.label}
</OpalButton>
</Disabled>
)}
</div>
</div>
: undefined
}
onDisconnect={
provider.provider_type !== "onyx_web_crawler" &&
isConfigured &&
provider.id > 0
? () =>
setDisconnectTarget({
id: provider.id,
label,
category: "content",
providerType: provider.provider_type,
})
: undefined
}
/>
);
})}
</div>
@@ -1238,6 +1273,21 @@ export default function Page() {
</SettingsLayouts.Body>
</SettingsLayouts.Root>
{disconnectTarget && (
<WebSearchDisconnectModal
disconnectTarget={disconnectTarget}
searchProviders={searchProviders}
contentProviders={combinedContentProviders}
replacementProviderId={replacementProviderId}
onReplacementChange={setReplacementProviderId}
onClose={() => {
setDisconnectTarget(null);
setReplacementProviderId(null);
}}
onDisconnect={() => void handleDisconnectProvider()}
/>
)}
<WebProviderSetupModal
isOpen={selectedProviderType !== null}
onClose={() => {

View File

@@ -9,7 +9,7 @@ import {
TableCell,
TableHeader,
} from "@/components/ui/table";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import { Callout } from "@/components/ui/callout";
import { CCPairFullInfo } from "./types";
import { IndexAttemptSnapshot } from "@/lib/types";
@@ -153,17 +153,11 @@ export function IndexAttemptsTable({
</div>
</TableCell>
<TableCell>
{indexAttempt.status === "success" && (
<Text className="flex flex-wrap whitespace-normal">
{"-"}
</Text>
)}
{indexAttempt.status === "success" && <Text as="p">-</Text>}
{indexAttempt.status === "failed" &&
indexAttempt.error_msg && (
<Text className="flex flex-wrap whitespace-normal">
{indexAttempt.error_msg}
</Text>
<Text as="p">{indexAttempt.error_msg}</Text>
)}
</TableCell>
<td className="w-0 p-0">

View File

@@ -11,9 +11,10 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Button } from "@opal/components";
import { Button, Text } from "@opal/components";
import { Card } from "@/components/ui/card";
import Text from "@/components/ui/text";
import { markdown } from "@opal/utils";
import Spacer from "@/refresh-components/Spacer";
import { Spinner } from "@/components/Spinner";
import { SvgDownloadCloud } from "@opal/icons";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
@@ -75,11 +76,12 @@ function Main() {
<>
{isDownloading && <Spinner />}
<div className="mb-8">
<Text className="mb-3">
<b>Debug Logs</b> provide detailed information about system operations
and events. You can download logs for each category to analyze system
behavior or troubleshoot issues.
<Text as="p">
{markdown(
"**Debug Logs** provide detailed information about system operations and events. You can download logs for each category to analyze system behavior or troubleshoot issues."
)}
</Text>
<Spacer rem={0.75} />
{categories.length > 0 && (
<Card className="mt-4">

View File

@@ -10,7 +10,9 @@ import {
TableBody,
TableCell,
} from "@/components/ui/table";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import { markdown } from "@opal/utils";
import Spacer from "@/refresh-components/Spacer";
import Title from "@/components/ui/title";
import Separator from "@/refresh-components/Separator";
import { DocumentSetSummary } from "@/lib/types";
@@ -393,11 +395,12 @@ function Main() {
return (
<div className="mb-8">
<Text className="mb-3">
<b>Document Sets</b> allow you to group logically connected documents
into a single bundle. These can then be used as a filter when performing
searches to control the scope of information Onyx searches over.
<Text as="p">
{markdown(
"**Document Sets** allow you to group logically connected documents into a single bundle. These can then be used as a filter when performing searches to control the scope of information Onyx searches over."
)}
</Text>
<Spacer rem={0.75} />
<div className="mb-3"></div>

View File

@@ -1,6 +1,8 @@
"use client";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import { markdown } from "@opal/utils";
import Spacer from "@/refresh-components/Spacer";
import Title from "@/components/ui/title";
import {
CloudEmbeddingProvider,
@@ -99,10 +101,12 @@ export default function CloudEmbeddingPage({
<Title className="mt-8">
Here are some cloud-based models to choose from.
</Title>
<Text className="mb-4">
These models require API keys and run in the clouds of the respective
providers.
<Text as="p">
{
"These models require API keys and run in the clouds of the respective providers."
}
</Text>
<Spacer rem={1} />
<div className="gap-4 mt-2 pb-10 flex content-start flex-wrap">
{providers.map((provider) => (
@@ -156,18 +160,11 @@ export default function CloudEmbeddingPage({
</div>
))}
<Text className="mt-6">
Alternatively, you can use a self-hosted model using the LiteLLM
proxy. This allows you to leverage various LLM providers through a
unified interface that you control.{" "}
<a
href="https://docs.litellm.ai/"
target="_blank"
rel="noopener noreferrer"
className="text-blue-500 hover:underline"
>
Learn more about LiteLLM
</a>
<Spacer rem={1.5} />
<Text as="p">
{markdown(
"Alternatively, you can use a self-hosted model using the LiteLLM proxy. This allows you to leverage various LLM providers through a unified interface that you control. [Learn more about LiteLLM](https://docs.litellm.ai/)"
)}
</Text>
<div key={LITELLM_CLOUD_PROVIDER.provider_type} className="mt-4 w-full">
@@ -214,20 +211,25 @@ export default function CloudEmbeddingPage({
{!liteLLMProvider && (
<CardSection className="mt-2 w-full max-w-4xl bg-background-50 border border-background-200">
<div className="p-4">
<Text className="text-lg font-semibold mb-2">
<Text as="p" font="heading-h3">
API URL Required
</Text>
<Text className="text-sm text-text-600 mb-4">
Before you can add models, you need to provide an API URL
for your LiteLLM proxy. Click the &quot;Provide API
URL&quot; button above to set up your LiteLLM configuration.
<Spacer rem={0.5} />
<Text as="p">
{
'Before you can add models, you need to provide an API URL for your LiteLLM proxy. Click the "Provide API URL" button above to set up your LiteLLM configuration.'
}
</Text>
<Spacer rem={1} />
<div className="flex items-center">
<FiInfo className="text-blue-500 mr-2" size={18} />
<Text className="text-sm text-blue-500">
Once configured, you&apos;ll be able to add and manage
your LiteLLM models here.
</Text>
<span className="text-blue-500">
<Text as="p">
{
"Once configured, you'll be able to add and manage your LiteLLM models here."
}
</Text>
</span>
</div>
</div>
</CardSection>
@@ -281,9 +283,11 @@ export default function CloudEmbeddingPage({
</div>
</div>
<Text className="mt-6">
You can also use Azure OpenAI models for embeddings. Azure requires
separate configuration for each model.
<Spacer rem={1.5} />
<Text as="p">
{
"You can also use Azure OpenAI models for embeddings. Azure requires separate configuration for each model."
}
</Text>
<div key={AZURE_CLOUD_PROVIDER.provider_type} className="mt-4 w-full">
@@ -319,18 +323,22 @@ export default function CloudEmbeddingPage({
</button>
<div className="mt-2 w-full max-w-4xl">
<CardSection className="p-4 border border-background-200 rounded-lg shadow-sm">
<Text className="text-base font-medium mb-2">
<Text as="p" font="main-ui-action">
Configure Azure OpenAI for Embeddings
</Text>
<Text className="text-sm text-text-600 mb-3">
Click &quot;Configure Azure OpenAI&quot; to set up Azure
OpenAI for embeddings.
<Spacer rem={0.5} />
<Text as="p">
{
'Click "Configure Azure OpenAI" to set up Azure OpenAI for embeddings.'
}
</Text>
<div className="flex items-center text-sm text-text-700">
<Spacer rem={0.75} />
<div className="flex items-center">
<FiInfo className="text-neutral-400 mr-2" size={16} />
<Text>
You&apos;ll need: API version, base URL, API key, model
name, and deployment name.
<Text as="p">
{
"You'll need: API version, base URL, API key, model name, and deployment name."
}
</Text>
</div>
</CardSection>
@@ -339,9 +347,10 @@ export default function CloudEmbeddingPage({
) : (
<>
<div className="mb-6 w-full">
<Text className="text-lg font-semibold mb-3">
<Text as="p" font="heading-h3">
Current Azure Configuration
</Text>
<Spacer rem={0.75} />
{azureProviderDetails ? (
<CardSection className="bg-white shadow-sm border border-background-200 rounded-lg">

View File

@@ -1,7 +1,9 @@
"use client";
import Button from "@/refresh-components/buttons/Button";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import { markdown } from "@opal/utils";
import Spacer from "@/refresh-components/Spacer";
import Title from "@/components/ui/title";
import { ModelSelector } from "../../../../components/embedding/ModelSelector";
import {
@@ -25,41 +27,28 @@ export default function OpenEmbeddingPage({
<Title className="mt-8">
Here are some locally-hosted models to choose from.
</Title>
<Text className="mb-4">
These models can be used without any API keys, and can leverage a GPU
for faster inference.
<Text as="p">
{
"These models can be used without any API keys, and can leverage a GPU for faster inference."
}
</Text>
<Spacer rem={1} />
<ModelSelector
modelOptions={AVAILABLE_MODELS}
setSelectedModel={onSelectOpenSource}
currentEmbeddingModel={selectedProvider}
/>
<Text className="mt-6">
Alternatively, (if you know what you&apos;re doing) you can specify a{" "}
<a
target="_blank"
href="https://www.sbert.net/"
className="text-link"
rel="noreferrer"
>
SentenceTransformers
</a>
-compatible model of your choice below. The rough list of supported
models can be found{" "}
<a
target="_blank"
href="https://huggingface.co/models?library=sentence-transformers&sort=trending"
className="text-link"
rel="noreferrer"
>
here
</a>
.
<br />
<b>NOTE:</b> not all models listed will work with Onyx, since some have
unique interfaces or special requirements. If in doubt, reach out to the
Onyx team.
<Spacer rem={1.5} />
<Text as="p">
{markdown(
"Alternatively, (if you know what you're doing) you can specify a [SentenceTransformers](https://www.sbert.net/)-compatible model of your choice below. The rough list of supported models can be found [here](https://huggingface.co/models?library=sentence-transformers&sort=trending)."
)}
</Text>
<Text as="p">
{markdown(
"**NOTE:** not all models listed will work with Onyx, since some have unique interfaces or special requirements. If in doubt, reach out to the Onyx team."
)}
</Text>
{!configureModel && (
// TODO(@raunakab): migrate to opal Button once className/iconClassName is resolved

View File

@@ -5,7 +5,9 @@ import { SearchAndFilterControls } from "./SearchAndFilterControls";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Link from "next/link";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import { markdown } from "@opal/utils";
import Spacer from "@/refresh-components/Spacer";
import { useConnectorIndexingStatusWithPagination } from "@/lib/hooks";
import { useToastFromQuery } from "@/hooks/useToast";
import { Button } from "@opal/components";
@@ -185,13 +187,14 @@ function Main() {
<ConnectorStaggeredSkeleton rowCount={8} standalone={true} />
</div>
) : !ccPairsIndexingStatuses || ccPairsIndexingStatuses.length === 0 ? (
<Text className="mt-12">
It looks like you don&apos;t have any connectors setup yet. Visit the{" "}
<Link className="text-link" href="/admin/add-connector">
Add Connector
</Link>{" "}
page to get started!
</Text>
<div>
<Spacer rem={3} />
<Text as="p">
{markdown(
"It looks like you don't have any connectors setup yet. Visit the [Add Connector](/admin/add-connector) page to get started!"
)}
</Text>
</div>
) : (
<CCPairIndexingStatusTable
ccPairsIndexingStatuses={ccPairsIndexingStatuses}

View File

@@ -16,7 +16,8 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
import useSWR, { mutate } from "swr";
import Checkbox from "@/refresh-components/inputs/Checkbox";
import { TableHeader } from "@/components/ui/table";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import Spacer from "@/refresh-components/Spacer";
type TokenRateLimitTableArgs = {
tokenRateLimits: TokenRateLimitDisplay[];
@@ -68,11 +69,15 @@ export const TokenRateLimitTable = ({
<div className="w-full">
{!hideHeading && title && <Title>{title}</Title>}
{!hideHeading && description && (
<Text className="my-2">{description}</Text>
<>
<Spacer rem={0.5} />
<Text as="p">{description}</Text>
<Spacer rem={0.5} />
</>
)}
<Text className={`${!hideHeading && "my-8"}`}>
No token rate limits set!
</Text>
{!hideHeading && <Spacer rem={2} />}
<Text as="p">No token rate limits set!</Text>
{!hideHeading && <Spacer rem={2} />}
</div>
);
}
@@ -81,7 +86,11 @@ export const TokenRateLimitTable = ({
<div className="w-full">
{!hideHeading && title && <Title>{title}</Title>}
{!hideHeading && description && (
<Text className="my-2">{description}</Text>
<>
<Spacer rem={0.5} />
<Text as="p">{description}</Text>
<Spacer rem={0.5} />
</>
)}
<Table
className={`overflow-visible ${
@@ -188,7 +197,7 @@ export const GenericTokenRateLimitTable = ({
}
if (!isLoading && error) {
return <Text>Failed to load token rate limits</Text>;
return <Text as="p">Failed to load token rate limits</Text>;
}
let processedData = data;

View File

@@ -2,7 +2,7 @@
import SimpleTabs from "@/refresh-components/SimpleTabs";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import { useState } from "react";
import {
insertGlobalTokenRateLimit,
@@ -104,14 +104,14 @@ function Main() {
return (
<Section alignItems="stretch" justifyContent="start" height="auto">
<Text>
<Text as="p">
Token rate limits enable you control how many tokens can be spent in a
given time period. With token rate limits, you can:
</Text>
<ul className="list-disc ml-4">
<li>
<Text>
<Text as="p">
Set a global rate limit to control your team&apos;s overall token
spend.
</Text>
@@ -119,13 +119,13 @@ function Main() {
{isPaidEnterpriseFeaturesEnabled && (
<>
<li>
<Text>
<Text as="p">
Set rate limits for users to ensure that no single user can
spend too many tokens.
</Text>
</li>
<li>
<Text>
<Text as="p">
Set rate limits for user groups to control token spend for your
teams.
</Text>
@@ -133,7 +133,7 @@ function Main() {
</>
)}
<li>
<Text>Enable and disable rate limits on the fly.</Text>
<Text as="p">Enable and disable rate limits on the fly.</Text>
</li>
</ul>

View File

@@ -3,7 +3,9 @@ import React, { useState } from "react";
import { forgotPassword } from "./utils";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import Title from "@/components/ui/title";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import { markdown } from "@opal/utils";
import Spacer from "@/refresh-components/Spacer";
import Link from "next/link";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
@@ -73,12 +75,11 @@ const ForgotPasswordPage: React.FC = () => {
</Form>
)}
</Formik>
<Spacer rem={1} />
<div className="flex">
<Text className="mt-4 mx-auto">
<Link href="/auth/login" className="text-link font-medium">
Back to Login
</Link>
</Text>
<div className="mx-auto">
<Text as="p">{markdown("[Back to Login](/auth/login)")}</Text>
</div>
</div>
</div>
</AuthFlowContainer>

View File

@@ -3,7 +3,9 @@ import React, { useState, useEffect } from "react";
import { resetPassword } from "../forgot-password/utils";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import Title from "@/components/ui/title";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import { markdown } from "@opal/utils";
import Spacer from "@/refresh-components/Spacer";
import Link from "next/link";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
@@ -109,12 +111,11 @@ const ResetPasswordPage: React.FC = () => {
</Form>
)}
</Formik>
<Spacer rem={1} />
<div className="flex">
<Text className="mt-4 mx-auto">
<Link href="/auth/login" className="text-link font-medium">
Back to Login
</Link>
</Text>
<div className="mx-auto">
<Text as="p">{markdown("[Back to Login](/auth/login)")}</Text>
</div>
</div>
</div>
</AuthFlowContainer>

View File

@@ -2,7 +2,8 @@
import { useSearchParams } from "next/navigation";
import { useCallback, useEffect, useState } from "react";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import Spacer from "@/refresh-components/Spacer";
import { RequestNewVerificationEmail } from "../waiting-on-verification/RequestNewVerificationEmail";
import { User } from "@/lib/types";
import Logo from "@/refresh-components/Logo";
@@ -65,17 +66,22 @@ export default function Verify({ user }: VerifyProps) {
<div className="min-h-screen flex flex-col items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
<Logo folded size={64} className="mx-auto w-fit animate-pulse" />
{!error ? (
<Text className="mt-2">Verifying your email...</Text>
<>
<Spacer rem={0.5} />
<Text as="p">Verifying your email...</Text>
</>
) : (
<div>
<Text className="mt-2">{error}</Text>
<Spacer rem={0.5} />
<Text as="p">{error}</Text>
{user && (
<div className="text-center">
<RequestNewVerificationEmail email={user.email}>
<Text className="mt-2 text-link">
{/* TODO(@raunakab): migrate to @opal/components Text */}
<p className="text-sm mt-2 text-link">
Get new verification email
</Text>
</p>
</RequestNewVerificationEmail>
</div>
)}

View File

@@ -4,11 +4,11 @@ import {
getCurrentUserSS,
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { User } from "@/lib/types";
import Text from "@/components/ui/text";
import { RequestNewVerificationEmail } from "./RequestNewVerificationEmail";
import Logo from "@/refresh-components/Logo";
import { Text } from "@opal/components";
import { markdown } from "@opal/utils";
export default async function Page() {
// catch cases where the backend is completely unreachable here
@@ -35,22 +35,21 @@ export default async function Page() {
return (
<main>
<div className="min-h-screen flex flex-col items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
<div className="min-h-screen flex flex-col items-center justify-center py-12 px-4 sm:px-6 lg:px-8 gap-4">
<Logo folded size={64} className="mx-auto w-fit" />
<div className="flex">
<Text className="text-center font-medium text-lg mt-6 w-108">
Hey <i>{currentUser.email}</i> - it looks like you haven&apos;t
verified your email yet.
<br />
Check your inbox for an email from us to get started!
<br />
<br />
If you don&apos;t see anything, click{" "}
<RequestNewVerificationEmail email={currentUser.email}>
here
</RequestNewVerificationEmail>{" "}
to request a new email.
<div className="flex flex-col gap-2">
<Text as="span">
{markdown(
`Hey, *${currentUser.email}*, it looks like you haven't verified your email yet.\nCheck your inbox for an email from us to get started!`
)}
</Text>
<div className="flex flex-row items-center gap-1">
<Text as="span">If you don't see anything, click</Text>
<RequestNewVerificationEmail email={currentUser.email}>
<Text as="span">here</Text>
</RequestNewVerificationEmail>
<Text as="span">to request a new email.</Text>
</div>
</div>
</div>
</main>

View File

@@ -19,6 +19,10 @@
background-color: var(--background-neutral-00);
border: 1px solid var(--status-error-05);
}
.input-error:focus:not(:active),
.input-error:focus-within:not(:active) {
box-shadow: inset 0px 0px 0px 2px var(--background-tint-04);
}
.input-disabled {
background-color: var(--background-neutral-03);

View File

@@ -3,9 +3,9 @@
import { Label, SubLabel } from "@/components/Field";
import { toast } from "@/hooks/useToast";
import { SettingsContext } from "@/providers/SettingsProvider";
import { Button } from "@opal/components";
import { Button, Text } from "@opal/components";
import { markdown } from "@opal/utils";
import { Callout } from "@/components/ui/callout";
import Text from "@/components/ui/text";
import { useContext, useState } from "react";
import InputTextArea from "@/refresh-components/inputs/InputTextArea";
import Spacer from "@/refresh-components/Spacer";
@@ -54,17 +54,17 @@ export function CustomAnalyticsUpdateForm() {
>
<div className="mb-4">
<Label>Script</Label>
<Text className="mb-3">
<Text as="p">
Specify the Javascript that should run on page load in order to
initialize your custom tracking/analytics.
</Text>
<Text className="mb-2">
Do not include the{" "}
<span className="font-mono">&lt;script&gt;&lt;/script&gt;</span>{" "}
tags. If you upload a script below but you are not recieving any
events in your analytics platform, try removing all extra whitespace
before each line of JavaScript.
<Spacer rem={0.75} />
<Text as="p">
{markdown(
"Do not include the `<script></script>` tags. If you upload a script below but you are not receiving any events in your analytics platform, try removing all extra whitespace before each line of JavaScript."
)}
</Text>
<Spacer rem={0.5} />
<InputTextArea
value={newCustomAnalyticsScript}
onChange={(event) =>

View File

@@ -2,7 +2,8 @@ import * as SettingsLayouts from "@/layouts/settings-layouts";
import { CUSTOM_ANALYTICS_ENABLED } from "@/lib/constants";
import { Callout } from "@/components/ui/callout";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import Spacer from "@/refresh-components/Spacer";
import { CustomAnalyticsUpdateForm } from "./CustomAnalyticsUpdateForm";
const route = ADMIN_ROUTES.CUSTOM_ANALYTICS;
@@ -24,11 +25,12 @@ function Main() {
return (
<div>
<Text className="mb-8">
This allows you to bring your own analytics tool to Onyx! Copy the Web
snippet from your analytics provider into the box below, and we&apos;ll
start sending usage events.
<Text as="p">
{
"This allows you to bring your own analytics tool to Onyx! Copy the Web snippet from your analytics provider into the box below, and we'll start sending usage events."
}
</Text>
<Spacer rem={2} />
<CustomAnalyticsUpdateForm />
</div>

View File

@@ -1,9 +1,10 @@
"use client";
import { use } from "react";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import Title from "@/components/ui/title";
import Separator from "@/refresh-components/Separator";
import Spacer from "@/refresh-components/Spacer";
import { ChatSessionSnapshot, MessageSnapshot } from "../../usage/types";
import { FiBook } from "react-icons/fi";
import { timestampToReadableDate } from "@/lib/dateUtils";
@@ -21,13 +22,13 @@ function MessageDisplay({ message }: { message: MessageSnapshot }) {
<p className="text-xs font-bold mb-1">
{message.message_type === "user" ? "User" : "AI"}
</p>
<Text>{message.message}</Text>
<Text as="p">{message.message}</Text>
{message.documents.length > 0 && (
<div className="flex flex-col gap-y-2 mt-2">
<p className="font-bold text-xs">Reference Documents</p>
{message.documents.slice(0, 5).map((document) => {
return (
<Text className="flex" key={document.document_id}>
<div className="text-sm flex" key={document.document_id}>
<FiBook
className={
"my-auto mr-1" + (document.link ? " text-link" : " ")
@@ -45,7 +46,7 @@ function MessageDisplay({ message }: { message: MessageSnapshot }) {
) : (
document.semantic_identifier
)}
</Text>
</div>
);
})}
</div>
@@ -53,7 +54,7 @@ function MessageDisplay({ message }: { message: MessageSnapshot }) {
{message.feedback_type && (
<div className="mt-2">
<p className="font-bold text-xs">Feedback</p>
{message.feedback_text && <Text>{message.feedback_text}</Text>}
{message.feedback_text && <Text as="p">{message.feedback_text}</Text>}
<div className="mt-1">
<FeedbackBadge feedback={message.feedback_type} />
</div>
@@ -95,14 +96,19 @@ export default function QueryPage(props: { params: Promise<{ id: string }> }) {
<CardSection className="mt-4">
<Title>Chat Session Details</Title>
<Text className="flex flex-wrap whitespace-normal mt-1 text-sm">
{chatSessionSnapshot.assistant_name}
</Text>
<Text className="flex flex-wrap whitespace-normal mt-1 text-xs">
{chatSessionSnapshot.user_email &&
`${chatSessionSnapshot.user_email}, `}
{timestampToReadableDate(chatSessionSnapshot.time_created)},{" "}
{chatSessionSnapshot.flow_type}
<Spacer rem={0.25} />
{chatSessionSnapshot.assistant_name && (
<Text as="p">{chatSessionSnapshot.assistant_name}</Text>
)}
<Spacer rem={0.25} />
<Text as="p">
{`${
chatSessionSnapshot.user_email
? `${chatSessionSnapshot.user_email}, `
: ""
}${timestampToReadableDate(chatSessionSnapshot.time_created)}, ${
chatSessionSnapshot.flow_type
}`}
</Text>
<Separator />

View File

@@ -1,6 +1,6 @@
import { ThreeDotsLoader } from "@/components/Loading";
import { getDatesList, useQueryAnalytics } from "../lib";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import Title from "@/components/ui/title";
import { DateRangePickerValue } from "@/components/dateRangeSelectors/AdminDateRangeSelector";
@@ -68,7 +68,7 @@ export function FeedbackChart({
return (
<CardSection className="mt-8">
<Title>Feedback</Title>
<Text>Thumbs Up / Thumbs Down over time</Text>
<Text as="p">Thumbs Up / Thumbs Down over time</Text>
{chart}
</CardSection>
);

View File

@@ -1,7 +1,7 @@
import { ThreeDotsLoader } from "@/components/Loading";
import { getDatesList, useOnyxBotAnalytics } from "../lib";
import { DateRangePickerValue } from "@/components/dateRangeSelectors/AdminDateRangeSelector";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import Title from "@/components/ui/title";
import CardSection from "@/components/admin/CardSection";
import { AreaChartDisplay } from "@/components/ui/areaChart";
@@ -69,7 +69,7 @@ export function OnyxBotChart({
return (
<CardSection className="mt-8">
<Title>Slack Channel</Title>
<Text>Total Queries vs Auto Resolved</Text>
<Text as="p">Total Queries vs Auto Resolved</Text>
{chart}
</CardSection>
);

View File

@@ -6,7 +6,7 @@ import {
usePersonaUniqueUsers,
} from "../lib";
import { DateRangePickerValue } from "@/components/dateRangeSelectors/AdminDateRangeSelector";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import Title from "@/components/ui/title";
import CardSection from "@/components/admin/CardSection";
import { AreaChartDisplay } from "@/components/ui/areaChart";
@@ -180,7 +180,9 @@ export function PersonaMessagesChart({
<CardSection className="mt-8">
<Title>Agent Analytics</Title>
<div className="flex flex-col gap-4">
<Text>Messages and unique users per day for the selected agent</Text>
<Text as="p">
Messages and unique users per day for the selected agent
</Text>
<div className="flex items-center gap-4">
<Select
value={selectedPersonaId?.toString() ?? ""}

View File

@@ -5,7 +5,7 @@ import { getDatesList, useQueryAnalytics, useUserAnalytics } from "../lib";
import { ThreeDotsLoader } from "@/components/Loading";
import { AreaChartDisplay } from "@/components/ui/areaChart";
import Title from "@/components/ui/title";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import CardSection from "@/components/admin/CardSection";
export function QueryPerformanceChart({
@@ -98,7 +98,7 @@ export function QueryPerformanceChart({
return (
<CardSection className="mt-8">
<Title>Usage</Title>
<Text>Usage over time</Text>
<Text as="p">Usage over time</Text>
{chart}
</CardSection>
);

View File

@@ -12,8 +12,9 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import Title from "@/components/ui/title";
import Spacer from "@/refresh-components/Spacer";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
import { Disabled } from "@opal/core";
@@ -98,9 +99,8 @@ function GenerateReportInput({
return (
<div className="mb-8">
<Title className="mb-2">Generate Usage Reports</Title>
<Text className="mb-8">
Generate usage statistics for users in the workspace.
</Text>
<Text as="p">Generate usage statistics for users in the workspace.</Text>
<Spacer rem={2} />
<div className="grid gap-2 mb-3">
<Popover>
<Popover.Trigger asChild>
@@ -412,9 +412,9 @@ export default function UsageReports() {
isWaitingForReport={isWaitingForReport}
/>
{timeoutMessage && (
<div className="mb-4 p-4 bg-amber-50 dark:bg-amber-950/20 border border-amber-200 dark:border-amber-800 rounded-regular">
<div className="mb-4 p-4 bg-status-warning-00 border border-status-warning-02 rounded-regular">
<div className="flex items-start gap-2">
<div className="text-amber-600 dark:text-amber-500 mt-0.5">
<div className="text-status-warning-05 mt-0.5">
<svg
className="w-5 h-5"
fill="none"
@@ -430,12 +430,15 @@ export default function UsageReports() {
</svg>
</div>
<div className="flex-1">
<Text className="text-amber-800 dark:text-amber-200 font-medium mb-1">
Report Generation In Progress
</Text>
<Text className="text-amber-700 dark:text-amber-300 text-sm">
{timeoutMessage}
</Text>
<div className="text-status-warning-05">
<Text as="p" font="main-ui-action">
Report Generation In Progress
</Text>
</div>
<Spacer rem={0.25} />
<div className="text-status-warning-05">
<Text as="p">{timeoutMessage}</Text>
</div>
</div>
</div>
</div>

View File

@@ -25,7 +25,9 @@ import { deleteStandardAnswer } from "./lib";
import { FilterDropdown } from "@/components/search/filtering/FilterDropdown";
import { FiTag } from "react-icons/fi";
import { PageSelector } from "@/components/PageSelector";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import { markdown } from "@opal/utils";
import Spacer from "@/refresh-components/Spacer";
import { TableHeader } from "@/components/ui/table";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { SvgEdit, SvgTrash } from "@opal/icons";
@@ -316,19 +318,17 @@ const StandardAnswersTable = ({
<div>
{paginatedStandardAnswers.length === 0 && (
<div className="flex justify-center">
<Text>No matching standard answers found...</Text>
<Text as="p">No matching standard answers found...</Text>
</div>
)}
</div>
{paginatedStandardAnswers.length > 0 && (
<>
<div className="mt-4">
<Text>
Ensure that you have added the category to the relevant{" "}
<a className="text-link" href="/admin/bots">
Slack Bot
</a>
.
<Text as="p">
{markdown(
"Ensure that you have added the category to the relevant [Slack Bot](/admin/bots)."
)}
</Text>
</div>
<div className="mt-4 flex justify-center">
@@ -389,14 +389,17 @@ function Main() {
return (
<div className="mb-8">
<Text className="mb-2">
Manage the standard answers for pre-defined questions.
<br />
Note: Currently, only questions asked from Slack can receive standard
answers.
<Text as="p">
{markdown(
"Manage the standard answers for pre-defined questions.\nNote: Currently, only questions asked from Slack can receive standard answers."
)}
</Text>
<Spacer rem={0.5} />
{standardAnswers.length == 0 && (
<Text className="mb-2">Add your first standard answer below!</Text>
<>
<Text as="p">Add your first standard answer below!</Text>
<Spacer rem={0.5} />
</>
)}
<div className="mb-2"></div>

View File

@@ -16,7 +16,7 @@ import { toast } from "@/hooks/useToast";
import CreateCredential from "./actions/CreateCredential";
import { CCPairFullInfo } from "@/app/admin/connector/[ccPairId]/types";
import ModifyCredential from "./actions/ModifyCredential";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import {
buildCCPairInfoUrl,
buildSimilarCredentialInfoURL,
@@ -185,7 +185,7 @@ export default function CredentialSection({
<div className="flex-grow flex flex-col justify-center">
<div className="flex items-center justify-between">
<div>
<Text className="font-medium">
<Text as="p">
{ccPair.credential.name ||
`Credential #${ccPair.credential.id}`}
</Text>

View File

@@ -1,6 +1,6 @@
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import { FaNewspaper, FaTrash } from "react-icons/fa";
import { TextFormField, TypedFileUploadFormField } from "@/components/Field";
@@ -51,7 +51,7 @@ export default function EditCredential({
return (
<div className="flex flex-col gap-y-6">
<Text>
<Text as="p">
Ensure that you update to a credential with the proper permissions!
</Text>

View File

@@ -7,7 +7,8 @@ import { Formik, Form } from "formik";
import * as Yup from "yup";
import { TextFormField, BooleanFormField } from "@/components/Field";
import { Dispatch, SetStateAction } from "react";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import Spacer from "@/refresh-components/Spacer";
import Button from "@/refresh-components/buttons/Button";
import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm";
@@ -59,10 +60,12 @@ export function CustomEmbeddingModelForm({
>
{({ isSubmitting, submitForm, errors }) => (
<Form>
<Text className="text-xl text-text-900 font-bold mb-4">
Specify details for your {getFormattedProviderName(embeddingType)}{" "}
Provider&apos;s model
<Text as="p" font="heading-h3">
{`Specify details for your ${getFormattedProviderName(
embeddingType
)} Provider's model`}
</Text>
<Spacer rem={1} />
<TextFormField
name="model_name"
label="Model Name:"

View File

@@ -13,7 +13,8 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table";
import Text from "@/components/ui/text";
import { Text } from "@opal/components";
import Spacer from "@/refresh-components/Spacer";
import Link from "next/link";
import { useState } from "react";
import { FiLink, FiMaximize2, FiTrash } from "react-icons/fi";
@@ -68,15 +69,21 @@ export function FailedReIndexAttempts({
/>
)}
<Text className="text-status-error-05 font-semibold mb-2">
Failed Re-indexing Attempts
</Text>
<Text className="text-status-error-05 mb-4">
The table below shows only the failed re-indexing attempts for existing
connectors. These failures require immediate attention. Once all
connectors have been re-indexed successfully, the new model will be used
for all search queries.
</Text>
<div className="text-status-error-05">
<Text as="p" font="main-ui-action">
Failed Re-indexing Attempts
</Text>
</div>
<Spacer rem={0.5} />
<div className="text-status-error-05">
<Text as="p">
The table below shows only the failed re-indexing attempts for
existing connectors. These failures require immediate attention. Once
all connectors have been re-indexed successfully, the new model will
be used for all search queries.
</Text>
</div>
<Spacer rem={1} />
<div>
<Table>
@@ -114,7 +121,7 @@ export function FailedReIndexAttempts({
<TableCell>
<div>
<Text className="flex flex-wrap whitespace-normal">
<Text as="p">
{reindexingProgress.error_msg || "-"}
</Text>
</div>

View File

@@ -5,6 +5,7 @@
* and support various display sizes.
*/
import React from "react";
import { SvgBifrost } from "@opal/icons";
import { render } from "@tests/setup/test-utils";
import { GithubIcon, GitbookIcon, ConfluenceIcon } from "./icons";
@@ -51,4 +52,15 @@ describe("Logo Icons", () => {
render(<GithubIcon size={100} className="custom-class" />);
}).not.toThrow();
});
test("renders the Bifrost icon with theme-aware colors", () => {
const { container } = render(
<SvgBifrost size={32} className="custom text-red-500 dark:text-black" />
);
const icon = container.querySelector("svg");
expect(icon).toBeInTheDocument();
expect(icon).toHaveClass("custom", "text-[#33C19E]", "dark:text-white");
expect(icon).not.toHaveClass("text-red-500", "dark:text-black");
});
});

View File

@@ -1,11 +0,0 @@
import { cn } from "@/lib/utils";
export default function Text({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) {
return <p className={cn("text-sm", className)}>{children}</p>;
}

View File

@@ -13,6 +13,7 @@ export enum LLMProviderName {
VERTEX_AI = "vertex_ai",
BEDROCK = "bedrock",
LITELLM_PROXY = "litellm_proxy",
BIFROST = "bifrost",
CUSTOM = "custom",
}
@@ -165,6 +166,21 @@ export interface LiteLLMProxyModelResponse {
model_name: string;
}
export interface BifrostFetchParams {
api_base?: string;
api_key?: string;
provider_name?: string;
signal?: AbortSignal;
}
export interface BifrostModelResponse {
name: string;
display_name: string;
max_input_tokens: number | null;
supports_image_input: boolean;
supports_reasoning: boolean;
}
export interface VertexAIFetchParams {
model_configurations?: ModelConfiguration[];
}
@@ -182,5 +198,6 @@ export type FetchModelsParams =
| OllamaFetchParams
| OpenRouterFetchParams
| LiteLLMProxyFetchParams
| BifrostFetchParams
| VertexAIFetchParams
| LMStudioFetchParams;

View File

@@ -1,5 +1,7 @@
"use client";
import type { RichStr } from "@opal/types";
import { resolveStr } from "@opal/components/text/InlineMarkdown";
import Text from "@/refresh-components/texts/Text";
import { SvgXOctagon, SvgAlertCircle } from "@opal/icons";
import { useField, useFormikContext } from "formik";
@@ -12,9 +14,9 @@ interface OrientationLayoutProps {
disabled?: boolean;
nonInteractive?: boolean;
children?: React.ReactNode;
title: string;
title: string | RichStr;
titleSuffix?: string;
description?: string;
description?: string | RichStr;
optional?: boolean;
sizePreset?: "main-content" | "main-ui";
}
@@ -42,7 +44,7 @@ interface OrientationLayoutProps {
* ```
*/
export interface VerticalLayoutProps extends OrientationLayoutProps {
subDescription?: React.ReactNode;
subDescription?: string | RichStr;
}
function VerticalInputLayout({
name,
@@ -70,7 +72,7 @@ function VerticalInputLayout({
{name && <ErrorLayout name={name} />}
{subDescription && (
<Text secondaryBody text03>
{subDescription}
{resolveStr(subDescription)}
</Text>
)}
</Section>

Some files were not shown because too many files have changed in this diff Show More