mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-27 10:32:41 +00:00
Compare commits
17 Commits
multi-mode
...
bo/hook_ui
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9f59aad42 | ||
|
|
b9e84c42a8 | ||
|
|
0a1df52c2f | ||
|
|
306b0d452f | ||
|
|
5fdb34ba8e | ||
|
|
2d066631e3 | ||
|
|
5c84f6c61b | ||
|
|
899179d4b6 | ||
|
|
80d6bafc74 | ||
|
|
2cc325cb0e | ||
|
|
849385b756 | ||
|
|
417b9c12e4 | ||
|
|
30b37d0a77 | ||
|
|
b48be0cd3a | ||
|
|
127fd90424 | ||
|
|
f9c9e55f32 | ||
|
|
5afcf1acea |
64
.greptile/config.json
Normal file
64
.greptile/config.json
Normal file
@@ -0,0 +1,64 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
}
|
||||
57
.greptile/files.json
Normal file
57
.greptile/files.json
Normal file
@@ -0,0 +1,57 @@
|
||||
[
|
||||
{
|
||||
"scope": [],
|
||||
"path": "contributing_guides/best_practices.md",
|
||||
"description": "Best practices for contributing to the codebase"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/AGENTS.md",
|
||||
"description": "Frontend coding standards for the web directory"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/tests/README.md",
|
||||
"description": "Frontend testing guide and conventions"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/CLAUDE.md",
|
||||
"description": "Single source of truth for frontend coding standards"
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"path": "web/lib/opal/README.md",
|
||||
"description": "Opal component library usage guide"
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**"],
|
||||
"path": "backend/tests/README.md",
|
||||
"description": "Backend testing guide covering all 4 test types, fixtures, and conventions"
|
||||
},
|
||||
{
|
||||
"scope": ["backend/onyx/connectors/**"],
|
||||
"path": "backend/onyx/connectors/README.md",
|
||||
"description": "Connector development guide covering design, interfaces, and required changes"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "CLAUDE.md",
|
||||
"description": "Project instructions and coding standards"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "backend/alembic/README.md",
|
||||
"description": "Migration guidance, including multi-tenant migration behavior"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/helm/charts/onyx/values-lite.yaml",
|
||||
"description": "Lite deployment Helm values and service assumptions"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
|
||||
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
|
||||
}
|
||||
]
|
||||
29
.greptile/rules.md
Normal file
29
.greptile/rules.md
Normal file
@@ -0,0 +1,29 @@
|
||||
# Greptile Review Rules
|
||||
|
||||
## Type Annotations
|
||||
|
||||
Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code.
|
||||
|
||||
## Best Practices
|
||||
|
||||
Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly.
|
||||
|
||||
## TODOs
|
||||
|
||||
Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of `TODO(name): ...` or `TODO(1234): ...`
|
||||
|
||||
## Debugging Code
|
||||
|
||||
Remove temporary debugging code before merging to production, especially tenant-specific debugging logs.
|
||||
|
||||
## Hardcoded Booleans
|
||||
|
||||
When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant.
|
||||
|
||||
## Multi-tenant vs Single-tenant
|
||||
|
||||
Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies.
|
||||
|
||||
## Full vs Lite Deployments
|
||||
|
||||
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.
|
||||
@@ -0,0 +1,35 @@
|
||||
"""remove voice_provider deleted column
|
||||
|
||||
Revision ID: 1d78c0ca7853
|
||||
Revises: a3f8b2c1d4e5
|
||||
Create Date: 2026-03-26 11:30:53.883127
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1d78c0ca7853"
|
||||
down_revision = "a3f8b2c1d4e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Hard-delete any soft-deleted rows before dropping the column
|
||||
op.execute("DELETE FROM voice_provider WHERE deleted = true")
|
||||
op.drop_column("voice_provider", "deleted")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"voice_provider",
|
||||
sa.Column(
|
||||
"deleted",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
@@ -1,8 +1,19 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
@@ -148,3 +159,114 @@ class ChatStateContainer:
|
||||
"""Thread-safe getter for emitted citations (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
with event streaming capabilities.
|
||||
|
||||
The wrapped function should accept emitter as first arg and use it to emit
|
||||
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
pass
|
||||
"""
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
# Run the function in a background thread
|
||||
thread = run_in_background(run_with_exception_capture)
|
||||
|
||||
pkt: Packet | None = None
|
||||
last_turn_index = 0 # Track the highest turn_index seen for stop packet
|
||||
last_cancel_check = time.monotonic()
|
||||
cancel_check_interval = 0.3 # Check for cancellation every 300ms
|
||||
try:
|
||||
while True:
|
||||
# Poll queue with 300ms timeout for natural stop signal checking
|
||||
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
|
||||
try:
|
||||
pkt = emitter.bus.get(timeout=0.3)
|
||||
except Empty:
|
||||
if not is_connected():
|
||||
# Stop signal detected
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = time.monotonic()
|
||||
continue
|
||||
|
||||
if pkt is not None:
|
||||
# Track the highest turn_index for the stop packet
|
||||
if pkt.placement and pkt.placement.turn_index > last_turn_index:
|
||||
last_turn_index = pkt.placement.turn_index
|
||||
|
||||
if isinstance(pkt.obj, OverallStop):
|
||||
yield pkt
|
||||
break
|
||||
elif isinstance(pkt.obj, PacketException):
|
||||
raise pkt.obj.exception
|
||||
else:
|
||||
yield pkt
|
||||
|
||||
# Check for cancellation periodically even when packets are flowing
|
||||
# This ensures stop signal is checked during active streaming
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_cancel_check >= cancel_check_interval:
|
||||
if not is_connected():
|
||||
# Stop signal detected during streaming
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = current_time
|
||||
finally:
|
||||
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,80 +1,19 @@
|
||||
import queue
|
||||
from queue import Queue
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
class Emitter:
|
||||
"""Routes packets produced during tool and LLM execution to the right destination.
|
||||
"""Use this inside tools to emit arbitrary UI progress."""
|
||||
|
||||
Operates in one of two modes determined by whether ``merged_queue`` is supplied:
|
||||
|
||||
**Standalone** (no ``merged_queue``): packets land on ``self.bus``. Used by tests,
|
||||
custom tools, and any caller that reads the emitter directly after execution.
|
||||
|
||||
**Streaming** (``merged_queue`` provided): packets are tagged with ``model_index``
|
||||
and placed as ``(key, packet)`` tuples on the shared queue for the
|
||||
``_run_models`` drain loop to consume and yield downstream.
|
||||
|
||||
Attributes:
|
||||
bus: Fallback queue for standalone mode. Always created so existing callers
|
||||
(tests, eval harnesses, custom-tool scripts) work without modification.
|
||||
|
||||
Args:
|
||||
model_idx: Index embedded in packet placements. Pass ``None`` for single-model
|
||||
runs to preserve the backwards-compatible wire format (``model_index=None``
|
||||
in the packet); pass an integer for each model in a multi-model run.
|
||||
merged_queue: Shared queue owned by the ``_run_models`` drain loop. When set,
|
||||
all ``emit()`` calls route here instead of ``self.bus``.
|
||||
|
||||
Example::
|
||||
|
||||
# Standalone — read from bus after the fact (tests, evals)
|
||||
emitter = Emitter()
|
||||
emitter.emit(packet)
|
||||
result = emitter.bus.get()
|
||||
|
||||
# Streaming — wired into _run_models (production path)
|
||||
emitter = Emitter(model_idx=0, merged_queue=merged_queue)
|
||||
emitter.emit(packet) # places (0, tagged_packet) on merged_queue
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_idx: int | None = None,
|
||||
merged_queue: "queue.Queue | None" = None,
|
||||
) -> None:
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
# Always created for backwards compatibility (tests, custom_tool, customer scripts, etc.)
|
||||
self.bus: Queue[Packet] = Queue()
|
||||
def __init__(self, bus: Queue):
|
||||
self.bus = bus
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
"""Emit a packet, routing it to the merged queue or the local bus.
|
||||
|
||||
In streaming mode, stamps the packet's placement with ``model_index`` before
|
||||
forwarding so the drain loop can attribute it to the correct model. In
|
||||
standalone mode, places the packet on ``self.bus`` unchanged.
|
||||
|
||||
Args:
|
||||
packet: The packet to emit.
|
||||
"""
|
||||
if self._merged_queue is not None:
|
||||
tagged_placement = Placement(
|
||||
turn_index=packet.placement.turn_index if packet.placement else 0,
|
||||
tab_index=packet.placement.tab_index if packet.placement else 0,
|
||||
sub_turn_index=(
|
||||
packet.placement.sub_turn_index if packet.placement else None
|
||||
),
|
||||
model_index=self._model_idx,
|
||||
)
|
||||
tagged_packet = Packet(placement=tagged_placement, obj=packet.obj)
|
||||
key = self._model_idx if self._model_idx is not None else 0
|
||||
self._merged_queue.put((key, tagged_packet))
|
||||
else:
|
||||
self.bus.put(packet)
|
||||
self.bus.put(packet) # Thread-safe
|
||||
|
||||
|
||||
def get_default_emitter() -> Emitter:
|
||||
return Emitter()
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
return emitter
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
0
backend/onyx/connectors/canvas/__init__.py
Normal file
0
backend/onyx/connectors/canvas/__init__.py
Normal file
192
backend/onyx/connectors/canvas/client.py
Normal file
192
backend/onyx/connectors/canvas/client.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
)
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Requests timeout in seconds.
|
||||
_CANVAS_CALL_TIMEOUT: int = 30
|
||||
_CANVAS_API_VERSION: str = "/api/v1"
|
||||
# Matches the "next" URL in a Canvas Link header, e.g.:
|
||||
# <https://canvas.example.com/api/v1/courses?page=2>; rel="next"
|
||||
# Captures the URL inside the angle brackets.
|
||||
_NEXT_LINK_PATTERN: re.Pattern[str] = re.compile(r'<([^>]+)>;\s*rel="next"')
|
||||
|
||||
|
||||
_STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
|
||||
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
|
||||
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
404: OnyxErrorCode.BAD_GATEWAY,
|
||||
429: OnyxErrorCode.RATE_LIMITED,
|
||||
}
|
||||
|
||||
|
||||
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
|
||||
"""Map an HTTP status code to the appropriate OnyxErrorCode.
|
||||
|
||||
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
|
||||
mapped to specific error codes; all other codes (unrecognised 4xx
|
||||
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
|
||||
"""
|
||||
if status_code in _STATUS_TO_ERROR_CODE:
|
||||
return _STATUS_TO_ERROR_CODE[status_code]
|
||||
return OnyxErrorCode.BAD_GATEWAY
|
||||
|
||||
|
||||
class CanvasApiClient:
|
||||
def __init__(
|
||||
self,
|
||||
bearer_token: str,
|
||||
canvas_base_url: str,
|
||||
) -> None:
|
||||
parsed_base = urlparse(canvas_base_url)
|
||||
if not parsed_base.hostname:
|
||||
raise ValueError("canvas_base_url must include a valid host")
|
||||
if parsed_base.scheme != "https":
|
||||
raise ValueError("canvas_base_url must use https")
|
||||
|
||||
self._bearer_token = bearer_token
|
||||
self.base_url = (
|
||||
canvas_base_url.rstrip("/").removesuffix(_CANVAS_API_VERSION)
|
||||
+ _CANVAS_API_VERSION
|
||||
)
|
||||
# Hostname is already validated above; reuse parsed_base instead
|
||||
# of re-parsing. Used by _parse_next_link to validate pagination URLs.
|
||||
self._expected_host: str = parsed_base.hostname
|
||||
|
||||
def get(
|
||||
self,
|
||||
endpoint: str = "",
|
||||
params: dict[str, Any] | None = None,
|
||||
full_url: str | None = None,
|
||||
) -> tuple[Any, str | None]:
|
||||
"""Make a GET request to the Canvas API.
|
||||
|
||||
Returns a tuple of (json_body, next_url).
|
||||
next_url is parsed from the Link header and is None if there are no more pages.
|
||||
If full_url is provided, it is used directly (for following pagination links).
|
||||
|
||||
Security note: full_url must only be set to values returned by
|
||||
``_parse_next_link``, which validates the host against the configured
|
||||
Canvas base URL. Passing an arbitrary URL would leak the bearer token.
|
||||
"""
|
||||
# full_url is used when following pagination (Canvas returns the
|
||||
# next-page URL in the Link header). For the first request we build
|
||||
# the URL from the endpoint name instead.
|
||||
url = full_url if full_url else self._build_url(endpoint)
|
||||
headers = self._build_headers()
|
||||
|
||||
response = rl_requests.get(
|
||||
url,
|
||||
headers=headers,
|
||||
params=params if not full_url else None,
|
||||
timeout=_CANVAS_CALL_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
response_json = response.json()
|
||||
except ValueError as e:
|
||||
if response.status_code < 300:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail=f"Invalid JSON in Canvas response: {e}",
|
||||
)
|
||||
logger.warning(
|
||||
"Failed to parse JSON from Canvas error response (status=%d): %s",
|
||||
response.status_code,
|
||||
e,
|
||||
)
|
||||
response_json = {}
|
||||
|
||||
if response.status_code >= 400:
|
||||
# Try to extract the most specific error message from the
|
||||
# Canvas response body. Canvas uses three different shapes
|
||||
# depending on the endpoint and error type:
|
||||
default_error: str = response.reason or f"HTTP {response.status_code}"
|
||||
error = default_error
|
||||
if isinstance(response_json, dict):
|
||||
# Shape 1: {"error": {"message": "Not authorized"}}
|
||||
error_field = response_json.get("error")
|
||||
if isinstance(error_field, dict):
|
||||
response_error = error_field.get("message", "")
|
||||
if response_error:
|
||||
error = response_error
|
||||
# Shape 2: {"error": "Invalid access token"}
|
||||
elif isinstance(error_field, str):
|
||||
error = error_field
|
||||
# Shape 3: {"errors": [{"message": "..."}]}
|
||||
# Used for validation errors. Only use as fallback if
|
||||
# we didn't already find a more specific message above.
|
||||
if error == default_error:
|
||||
errors_list = response_json.get("errors")
|
||||
if isinstance(errors_list, list) and errors_list:
|
||||
first_error = errors_list[0]
|
||||
if isinstance(first_error, dict):
|
||||
msg = first_error.get("message", "")
|
||||
if msg:
|
||||
error = msg
|
||||
raise OnyxError(
|
||||
_error_code_for_status(response.status_code),
|
||||
detail=error,
|
||||
status_code_override=response.status_code,
|
||||
)
|
||||
|
||||
next_url = self._parse_next_link(response.headers.get("Link", ""))
|
||||
return response_json, next_url
|
||||
|
||||
def _parse_next_link(self, link_header: str) -> str | None:
|
||||
"""Extract the 'next' URL from a Canvas Link header.
|
||||
|
||||
Only returns URLs whose host matches the configured Canvas base URL
|
||||
to prevent leaking the bearer token to arbitrary hosts.
|
||||
"""
|
||||
expected_host = self._expected_host
|
||||
for match in _NEXT_LINK_PATTERN.finditer(link_header):
|
||||
url = match.group(1)
|
||||
parsed_url = urlparse(url)
|
||||
if parsed_url.hostname != expected_host:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail=(
|
||||
"Canvas pagination returned an unexpected host "
|
||||
f"({parsed_url.hostname}); expected {expected_host}"
|
||||
),
|
||||
)
|
||||
if parsed_url.scheme != "https":
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail=(
|
||||
"Canvas pagination link must use https, "
|
||||
f"got {parsed_url.scheme!r}"
|
||||
),
|
||||
)
|
||||
return url
|
||||
return None
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
"""Return the Authorization header with the bearer token."""
|
||||
return {"Authorization": f"Bearer {self._bearer_token}"}
|
||||
|
||||
def _build_url(self, endpoint: str) -> str:
|
||||
"""Build a full Canvas API URL from an endpoint path.
|
||||
|
||||
Assumes endpoint is non-empty (e.g. ``"courses"``, ``"announcements"``).
|
||||
Only called on a first request, endpoint must be set for first request.
|
||||
Verify endpoint exists in case of future changes where endpoint might be optional.
|
||||
Leading slashes are stripped to avoid double-slash in the result.
|
||||
self.base_url is already normalized with no trailing slash.
|
||||
"""
|
||||
final_url = self.base_url
|
||||
clean_endpoint = endpoint.lstrip("/")
|
||||
if clean_endpoint:
|
||||
final_url += "/" + clean_endpoint
|
||||
return final_url
|
||||
74
backend/onyx/connectors/canvas/connector.py
Normal file
74
backend/onyx/connectors/canvas/connector.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Literal
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
|
||||
|
||||
class CanvasCourse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
course_code: str
|
||||
created_at: str
|
||||
workflow_state: str
|
||||
|
||||
|
||||
class CanvasPage(BaseModel):
|
||||
page_id: int
|
||||
url: str
|
||||
title: str
|
||||
body: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
course_id: int
|
||||
|
||||
|
||||
class CanvasAssignment(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str | None = None
|
||||
html_url: str
|
||||
course_id: int
|
||||
created_at: str
|
||||
updated_at: str
|
||||
due_at: str | None = None
|
||||
|
||||
|
||||
class CanvasAnnouncement(BaseModel):
|
||||
id: int
|
||||
title: str
|
||||
message: str | None = None
|
||||
html_url: str
|
||||
posted_at: str | None = None
|
||||
course_id: int
|
||||
|
||||
|
||||
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
|
||||
|
||||
|
||||
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
"""Checkpoint state for resumable Canvas indexing.
|
||||
|
||||
Fields:
|
||||
course_ids: Materialized list of course IDs to process.
|
||||
current_course_index: Index into course_ids for current course.
|
||||
stage: Which item type we're processing for the current course.
|
||||
next_url: Pagination cursor within the current stage. None means
|
||||
start from the first page; a URL means resume from that page.
|
||||
|
||||
Invariant:
|
||||
If current_course_index is incremented, stage must be reset to
|
||||
"pages" and next_url must be reset to None.
|
||||
"""
|
||||
|
||||
course_ids: list[int] = []
|
||||
current_course_index: int = 0
|
||||
stage: CanvasStage = "pages"
|
||||
next_url: str | None = None
|
||||
|
||||
def advance_course(self) -> None:
|
||||
"""Move to the next course and reset within-course state."""
|
||||
self.current_course_index += 1
|
||||
self.stage = "pages"
|
||||
self.next_url = None
|
||||
@@ -10,6 +10,7 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
from jira.resources import Issue
|
||||
@@ -239,29 +240,53 @@ def enhanced_search_ids(
|
||||
)
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO: move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
def _bulk_fetch_request(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
|
||||
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
|
||||
|
||||
# Prepare the payload according to Jira API v3 specification
|
||||
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
|
||||
|
||||
# Only restrict fields if specified, might want to explicitly do this in the future
|
||||
# to avoid reading unnecessary data
|
||||
payload["fields"] = fields.split(",") if fields else ["*all"]
|
||||
|
||||
resp = jira_client._session.post(bulk_fetch_path, json=payload)
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
|
||||
f"be decoded as JSON (response too large or truncated)."
|
||||
)
|
||||
raise
|
||||
|
||||
mid = len(issue_ids) // 2
|
||||
logger.warning(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise e
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
for issue in response["issues"]
|
||||
for issue in raw_issues
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -617,92 +617,6 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
|
||||
|
||||
Also advances ``latest_child_message_id`` so the preferred response becomes
|
||||
the active branch for any subsequent messages in the conversation.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
|
||||
preferred response is being set.
|
||||
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
|
||||
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
|
||||
|
||||
Raises:
|
||||
ValueError: If either message is not found, if ``user_message_id`` does not
|
||||
refer to a USER message, or if the assistant message is not a direct child
|
||||
of the user message.
|
||||
"""
|
||||
user_msg = db_session.get(ChatMessage, user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child "
|
||||
f"of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
user_msg.latest_child_message_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -925,8 +839,6 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -3135,8 +3135,6 @@ class VoiceProvider(Base):
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
@@ -17,39 +17,30 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
.order_by(VoiceProvider.name)
|
||||
).all()
|
||||
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int, include_deleted: bool = False
|
||||
db_session: Session, provider_id: int
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(VoiceProvider.deleted.is_(False))
|
||||
return db_session.scalar(stmt)
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_stt.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_tts.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
|
||||
)
|
||||
|
||||
|
||||
@@ -58,9 +49,7 @@ def fetch_voice_provider_by_type(
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.provider_type == provider_type)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
|
||||
)
|
||||
|
||||
|
||||
@@ -119,10 +108,10 @@ def upsert_voice_provider(
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Soft-delete a voice provider by ID."""
|
||||
"""Delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
provider.deleted = True
|
||||
db_session.delete(provider)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ class LlmProviderNames(str, Enum):
|
||||
LM_STUDIO = "lm_studio"
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
BIFROST = "bifrost"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Needed so things like:
|
||||
@@ -44,6 +45,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
]
|
||||
|
||||
|
||||
@@ -61,6 +63,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: "Ollama",
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
LlmProviderNames.BIFROST: "Bifrost",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -112,6 +115,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.VERTEX_AI,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -290,6 +290,17 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
|
||||
|
||||
# Bifrost: OpenAI-compatible proxy that expects model names in
|
||||
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
|
||||
# We route through LiteLLM's openai provider with the Bifrost base URL,
|
||||
# and ensure /v1 is appended.
|
||||
if model_provider == LlmProviderNames.BIFROST:
|
||||
self._custom_llm_provider = "openai"
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
model_kwargs["api_base"] = self._api_base
|
||||
|
||||
# This is needed for Ollama to do proper function calling
|
||||
if model_provider == LlmProviderNames.OLLAMA_CHAT and api_base is not None:
|
||||
model_kwargs["api_base"] = api_base
|
||||
@@ -401,14 +412,20 @@ class LitellmLLM(LLM):
|
||||
optional_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Model name
|
||||
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
|
||||
model_provider = (
|
||||
f"{self.config.model_provider}/responses"
|
||||
if is_openai_model # Uses litellm's completions -> responses bridge
|
||||
else self.config.model_provider
|
||||
)
|
||||
model = (
|
||||
f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
)
|
||||
if is_bifrost:
|
||||
# Bifrost expects model names in provider/model format
|
||||
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
|
||||
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
|
||||
# so LiteLLM doesn't try to route based on the provider prefix.
|
||||
model = self.config.deployment_name or self.config.model_name
|
||||
else:
|
||||
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
|
||||
# Tool choice
|
||||
if is_claude_model and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
@@ -483,10 +500,11 @@ class LitellmLLM(LLM):
|
||||
if structured_response_format:
|
||||
optional_kwargs["response_format"] = structured_response_format
|
||||
|
||||
if not (is_claude_model or is_ollama or is_mistral):
|
||||
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
|
||||
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
|
||||
# However, this param breaks Anthropic and Mistral models,
|
||||
# so it must be conditionally included.
|
||||
# so it must be conditionally included unless the request is
|
||||
# routed through Bifrost's OpenAI-compatible endpoint.
|
||||
# Additionally, tool_choice is not supported by Ollama and causes warnings if included.
|
||||
# See also, https://github.com/ollama/ollama/issues/11171
|
||||
optional_kwargs["allowed_openai_params"] = ["tool_choice"]
|
||||
|
||||
@@ -8,24 +8,6 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMOverride(BaseModel):
|
||||
"""Per-request LLM settings that override persona defaults.
|
||||
|
||||
All fields are optional — only the fields that differ from the persona's
|
||||
configured LLM need to be supplied. Used both over the wire (API requests)
|
||||
and for multi-model comparison, where one override is supplied per model.
|
||||
|
||||
Attributes:
|
||||
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
|
||||
When ``None``, the persona's default provider is used.
|
||||
model_version: Specific model version string (e.g. ``"gpt-4o"``).
|
||||
When ``None``, the persona's default model is used.
|
||||
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
|
||||
persona's default temperature is used.
|
||||
display_name: Human-readable label shown in the UI for this model,
|
||||
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
|
||||
when not set.
|
||||
"""
|
||||
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
@@ -13,6 +13,8 @@ LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
|
||||
|
||||
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
BIFROST_PROVIDER_NAME = "bifrost"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.llm.well_known_providers.auto_update_service import (
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
@@ -49,6 +50,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -44,11 +44,12 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -122,9 +123,8 @@ def _validate_endpoint(
|
||||
(not reachable — indicates the api_key is invalid).
|
||||
|
||||
Timeout handling:
|
||||
- ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
|
||||
(operator should consider increasing timeout_seconds).
|
||||
- Any httpx.TimeoutException (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout) →
|
||||
timeout (operator should consider increasing timeout_seconds).
|
||||
- All other exceptions → cannot_connect.
|
||||
"""
|
||||
_check_ssrf_safety(endpoint_url)
|
||||
@@ -141,19 +141,11 @@ def _validate_endpoint(
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
# Any timeout (connect, read, or write) means the configured timeout_seconds
|
||||
# is too low for this endpoint. Report as timeout so the UI directs the user
|
||||
# to increase the timeout setting.
|
||||
logger.warning(
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
"Hook endpoint validation: timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
|
||||
@@ -57,6 +57,8 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import BifrostFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BifrostModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelDetails
|
||||
@@ -1422,11 +1424,26 @@ def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
return _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="LiteLLM proxy",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
def _get_openai_compatible_models_response(
|
||||
url: str,
|
||||
source_name: str,
|
||||
api_key: str | None = None,
|
||||
) -> dict:
|
||||
"""Fetch model metadata from an OpenAI-compatible `/models` endpoint."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
if not api_key:
|
||||
headers.pop("Authorization")
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
@@ -1436,20 +1453,125 @@ def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
if e.response.status_code == 401:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
|
||||
f"Authentication failed: invalid or missing API key for {source_name}.",
|
||||
)
|
||||
elif e.response.status_code == 404:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"LiteLLM models endpoint not found at {url}. Please verify the API base URL.",
|
||||
f"{source_name} models endpoint not found at {url}. Please verify the API base URL.",
|
||||
)
|
||||
else:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
f"Failed to fetch {source_name} models: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
except httpx.RequestError as e:
|
||||
logger.warning(
|
||||
"Failed to fetch models from OpenAI-compatible endpoint",
|
||||
extra={"source": source_name, "url": url, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
f"Failed to fetch {source_name} models: {e}",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Received invalid model response from OpenAI-compatible endpoint",
|
||||
extra={"source": source_name, "url": url, "error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch {source_name} models: {e}",
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/bifrost/available-models")
|
||||
def get_bifrost_available_models(
|
||||
request: BifrostModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BifrostFinalModelResponse]:
|
||||
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
|
||||
response_json = _get_bifrost_models_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Bifrost endpoint",
|
||||
)
|
||||
|
||||
results: list[BifrostFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_id = model.get("id", "")
|
||||
model_name = model.get("name", model_id)
|
||||
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
# Skip embedding models
|
||||
if is_embedding_model(model_id):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
BifrostFinalModelResponse(
|
||||
name=model_id,
|
||||
display_name=model_name,
|
||||
max_input_tokens=model.get("context_length"),
|
||||
supports_image_input=infer_vision_support(model_id),
|
||||
supports_reasoning=is_reasoning_model(model_id, model_name),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse Bifrost model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from Bifrost",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="Bifrost",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> dict:
|
||||
"""Perform GET to Bifrost /v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
# Ensure we hit /v1/models
|
||||
if cleaned_api_base.endswith("/v1"):
|
||||
url = f"{cleaned_api_base}/models"
|
||||
else:
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
return _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="Bifrost",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
@@ -449,3 +449,18 @@ class LitellmModelDetails(BaseModel):
|
||||
class LitellmFinalModelResponse(BaseModel):
|
||||
provider_name: str # Provider name (e.g. "openai")
|
||||
model_name: str # Model ID (e.g. "gpt-4o")
|
||||
|
||||
|
||||
# Bifrost dynamic models fetch
|
||||
class BifrostModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class BifrostFinalModelResponse(BaseModel):
|
||||
name: str # Model ID in provider/model format (e.g. "anthropic/claude-sonnet-4-6")
|
||||
display_name: str # Human-readable name from Bifrost API
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
@@ -25,6 +25,7 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.BIFROST,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -50,6 +51,25 @@ BEDROCK_VISION_MODELS = frozenset(
|
||||
}
|
||||
)
|
||||
|
||||
# Known Bifrost/OpenAI-compatible vision-capable model families where the
|
||||
# source API does not expose this metadata directly.
|
||||
BIFROST_VISION_MODEL_FAMILIES = frozenset(
|
||||
{
|
||||
"anthropic/claude-3",
|
||||
"anthropic/claude-4",
|
||||
"amazon/nova-pro",
|
||||
"amazon/nova-lite",
|
||||
"amazon/nova-premier",
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4.1",
|
||||
"google/gemini",
|
||||
"meta-llama/llama-3.2",
|
||||
"mistral/pixtral",
|
||||
"qwen/qwen2.5-vl",
|
||||
"qwen/qwen-vl",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def is_valid_bedrock_model(
|
||||
model_id: str,
|
||||
@@ -76,11 +96,18 @@ def is_valid_bedrock_model(
|
||||
def infer_vision_support(model_id: str) -> bool:
|
||||
"""Infer vision support from model ID when base model metadata unavailable.
|
||||
|
||||
Used for cross-region inference profiles when the base model isn't
|
||||
available in the user's region.
|
||||
Used for providers like Bedrock and Bifrost where vision support may
|
||||
need to be inferred from vendor/model naming conventions.
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
return any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS)
|
||||
if any(vision_model in model_id_lower for vision_model in BEDROCK_VISION_MODELS):
|
||||
return True
|
||||
|
||||
normalized_model_id = model_id_lower.replace(".", "/")
|
||||
return any(
|
||||
vision_model in normalized_model_id
|
||||
for vision_model in BIFROST_VISION_MODEL_FAMILIES
|
||||
)
|
||||
|
||||
|
||||
def generate_bedrock_display_name(model_id: str) -> str:
|
||||
@@ -322,7 +349,7 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
|
||||
- Ollama: "llama3:70b" → "Meta"
|
||||
- Ollama: "qwen2.5:7b" → "Alibaba"
|
||||
"""
|
||||
if provider == LlmProviderNames.OPENROUTER:
|
||||
if provider in (LlmProviderNames.OPENROUTER, LlmProviderNames.BIFROST):
|
||||
# Format: "vendor/model-name" e.g., "anthropic/claude-3-5-sonnet"
|
||||
if "/" in model_name:
|
||||
vendor_key = model_name.split("/")[0].lower()
|
||||
|
||||
@@ -449,40 +449,128 @@ class RedisHealthCollector(_CachedCollector):
|
||||
return [memory_used, memory_peak, memory_frag, connected_clients]
|
||||
|
||||
|
||||
class WorkerHealthCollector(_CachedCollector):
|
||||
"""Collects Celery worker count and process count via inspect ping.
|
||||
class WorkerHeartbeatMonitor:
|
||||
"""Monitors Celery worker health via the event stream.
|
||||
|
||||
Uses a longer cache TTL (60s) since inspect.ping() is a broadcast
|
||||
command that takes a couple seconds to complete.
|
||||
|
||||
Maintains a set of known worker short-names so that when a worker
|
||||
stops responding, we emit ``up=0`` instead of silently dropping the
|
||||
metric (which would make ``absent()``-style alerts impossible).
|
||||
Subscribes to ``worker-heartbeat``, ``worker-online``, and
|
||||
``worker-offline`` events via a single persistent connection.
|
||||
Runs in a daemon thread started once during worker setup.
|
||||
"""
|
||||
|
||||
# Remove a worker from _known_workers after this many consecutive
|
||||
# missed pings (at 60s TTL ≈ 10 minutes of being unreachable).
|
||||
_MAX_CONSECUTIVE_MISSES = 10
|
||||
# Consider a worker down if no heartbeat received for this long.
|
||||
_HEARTBEAT_TIMEOUT_SECONDS = 120.0
|
||||
|
||||
def __init__(self, cache_ttl: float = 60.0) -> None:
|
||||
def __init__(self, celery_app: Any) -> None:
|
||||
self._app = celery_app
|
||||
self._worker_last_seen: dict[str, float] = {}
|
||||
self._lock = threading.Lock()
|
||||
self._running = False
|
||||
self._thread: threading.Thread | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the background event listener thread.
|
||||
|
||||
Safe to call multiple times — only starts one thread.
|
||||
"""
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
return
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._listen, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info("WorkerHeartbeatMonitor started")
|
||||
|
||||
def stop(self) -> None:
|
||||
self._running = False
|
||||
|
||||
def _listen(self) -> None:
|
||||
"""Background loop: connect to event stream and process heartbeats."""
|
||||
while self._running:
|
||||
try:
|
||||
with self._app.connection() as conn:
|
||||
recv = self._app.events.Receiver(
|
||||
conn,
|
||||
handlers={
|
||||
"worker-heartbeat": self._on_heartbeat,
|
||||
"worker-online": self._on_heartbeat,
|
||||
"worker-offline": self._on_offline,
|
||||
},
|
||||
)
|
||||
recv.capture(
|
||||
limit=None, timeout=self._HEARTBEAT_TIMEOUT_SECONDS, wakeup=True
|
||||
)
|
||||
except Exception:
|
||||
if self._running:
|
||||
logger.debug(
|
||||
"Heartbeat listener disconnected, reconnecting in 5s",
|
||||
exc_info=True,
|
||||
)
|
||||
time.sleep(5.0)
|
||||
else:
|
||||
# capture() returned normally (timeout with no events); reconnect
|
||||
if self._running:
|
||||
logger.debug("Heartbeat capture timed out, reconnecting")
|
||||
time.sleep(5.0)
|
||||
|
||||
def _on_heartbeat(self, event: dict[str, Any]) -> None:
|
||||
hostname = event.get("hostname")
|
||||
if hostname:
|
||||
with self._lock:
|
||||
self._worker_last_seen[hostname] = time.monotonic()
|
||||
|
||||
def _on_offline(self, event: dict[str, Any]) -> None:
|
||||
hostname = event.get("hostname")
|
||||
if hostname:
|
||||
with self._lock:
|
||||
self._worker_last_seen.pop(hostname, None)
|
||||
|
||||
def get_worker_status(self) -> dict[str, bool]:
|
||||
"""Return {hostname: is_alive} for all known workers.
|
||||
|
||||
Thread-safe. Called by WorkerHealthCollector on each scrape.
|
||||
Also prunes workers that have been dead longer than 2x the
|
||||
heartbeat timeout to prevent unbounded growth.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
prune_threshold = self._HEARTBEAT_TIMEOUT_SECONDS * 2
|
||||
with self._lock:
|
||||
# Prune workers that have been gone for 2x the timeout
|
||||
stale = [
|
||||
h
|
||||
for h, ts in self._worker_last_seen.items()
|
||||
if (now - ts) > prune_threshold
|
||||
]
|
||||
for h in stale:
|
||||
del self._worker_last_seen[h]
|
||||
|
||||
result: dict[str, bool] = {}
|
||||
for hostname, last_seen in self._worker_last_seen.items():
|
||||
alive = (now - last_seen) < self._HEARTBEAT_TIMEOUT_SECONDS
|
||||
result[hostname] = alive
|
||||
return result
|
||||
|
||||
|
||||
class WorkerHealthCollector(_CachedCollector):
|
||||
"""Collects Celery worker health from the heartbeat monitor.
|
||||
|
||||
Reads worker status from ``WorkerHeartbeatMonitor`` which listens
|
||||
to the Celery event stream via a single persistent connection.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = 30.0) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._celery_app: Any | None = None
|
||||
# worker short-name → consecutive miss count.
|
||||
# Workers start at 0 and reset to 0 each time they respond.
|
||||
# Removed after _MAX_CONSECUTIVE_MISSES missed collects.
|
||||
self._known_workers: dict[str, int] = {}
|
||||
self._monitor: WorkerHeartbeatMonitor | None = None
|
||||
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app instance for inspect commands."""
|
||||
self._celery_app = app
|
||||
def set_monitor(self, monitor: WorkerHeartbeatMonitor) -> None:
|
||||
"""Set the heartbeat monitor instance."""
|
||||
self._monitor = monitor
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._celery_app is None:
|
||||
if self._monitor is None:
|
||||
return []
|
||||
|
||||
active_workers = GaugeMetricFamily(
|
||||
"onyx_celery_active_worker_count",
|
||||
"Number of active Celery workers responding to ping",
|
||||
"Number of active Celery workers with recent heartbeats",
|
||||
)
|
||||
worker_up = GaugeMetricFamily(
|
||||
"onyx_celery_worker_up",
|
||||
@@ -491,37 +579,15 @@ class WorkerHealthCollector(_CachedCollector):
|
||||
)
|
||||
|
||||
try:
|
||||
inspector = self._celery_app.control.inspect(timeout=3.0)
|
||||
ping_result = inspector.ping()
|
||||
status = self._monitor.get_worker_status()
|
||||
alive_count = sum(1 for alive in status.values() if alive)
|
||||
active_workers.add_metric([], alive_count)
|
||||
|
||||
responding: set[str] = set()
|
||||
if ping_result:
|
||||
active_workers.add_metric([], len(ping_result))
|
||||
for worker_name in ping_result:
|
||||
# Strip hostname suffix for cleaner labels
|
||||
short_name = worker_name.split("@")[0]
|
||||
responding.add(short_name)
|
||||
else:
|
||||
active_workers.add_metric([], 0)
|
||||
|
||||
# Register newly-seen workers and reset miss count for
|
||||
# workers that responded.
|
||||
for short_name in responding:
|
||||
self._known_workers[short_name] = 0
|
||||
|
||||
# Increment miss count for non-responding workers and evict
|
||||
# those that have been missing too long.
|
||||
stale = []
|
||||
for short_name in list(self._known_workers):
|
||||
if short_name not in responding:
|
||||
self._known_workers[short_name] += 1
|
||||
if self._known_workers[short_name] >= self._MAX_CONSECUTIVE_MISSES:
|
||||
stale.append(short_name)
|
||||
for short_name in stale:
|
||||
del self._known_workers[short_name]
|
||||
|
||||
for short_name in sorted(self._known_workers):
|
||||
worker_up.add_metric([short_name], 1 if short_name in responding else 0)
|
||||
for hostname in sorted(status):
|
||||
# Use short name (before @) for single-host deployments,
|
||||
# full hostname when multiple hosts share a worker type.
|
||||
label = hostname.split("@")[0]
|
||||
worker_up.add_metric([label], 1 if status[hostname] else 0)
|
||||
except Exception:
|
||||
logger.debug("Failed to collect worker health metrics", exc_info=True)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHeartbeatMonitor
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -28,6 +29,7 @@ _attempt_collector = IndexAttemptCollector()
|
||||
_connector_collector = ConnectorHealthCollector()
|
||||
_redis_health_collector = RedisHealthCollector()
|
||||
_worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
|
||||
|
||||
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
|
||||
@@ -96,7 +98,16 @@ def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
redis_factory = _make_broker_redis_factory(celery_app)
|
||||
_queue_collector.set_redis_factory(redis_factory)
|
||||
_redis_health_collector.set_redis_factory(redis_factory)
|
||||
_worker_health_collector.set_celery_app(celery_app)
|
||||
|
||||
# Start the heartbeat monitor daemon thread — uses a single persistent
|
||||
# connection to receive worker-heartbeat events.
|
||||
# Module-level singleton prevents duplicate threads on re-entry.
|
||||
global _heartbeat_monitor
|
||||
if _heartbeat_monitor is None:
|
||||
_heartbeat_monitor = WorkerHeartbeatMonitor(celery_app)
|
||||
_heartbeat_monitor.start()
|
||||
_worker_health_collector.set_monitor(_heartbeat_monitor)
|
||||
|
||||
_attempt_collector.configure()
|
||||
_connector_collector.configure()
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.chat.chat_utils import extract_headers
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
@@ -47,7 +46,6 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -62,8 +60,6 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -85,7 +81,6 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -575,46 +570,6 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -705,30 +660,6 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
try:
|
||||
# Ownership check: get_chat_message raises ValueError if the message
|
||||
# doesn't belong to this user, preventing cross-user mutation.
|
||||
get_chat_message(
|
||||
chat_message_id=request_body.user_message_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -2,24 +2,11 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class Placement(BaseModel):
|
||||
"""Coordinates that identify where a streaming packet belongs in the UI.
|
||||
|
||||
The frontend uses these fields to route each packet to the correct turn,
|
||||
tool tab, agent sub-turn, and (in multi-model mode) response column.
|
||||
|
||||
Attributes:
|
||||
turn_index: Monotonically increasing index of the iterative reasoning block
|
||||
(e.g. tool call round) within this chat message. Lower values happened first.
|
||||
tab_index: Disambiguates parallel tool calls within the same turn so each
|
||||
tool's output can be displayed in its own tab.
|
||||
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
|
||||
top-level packets; an integer for tool-within-tool output.
|
||||
model_index: Which model this packet belongs to in a multi-model comparison
|
||||
(0, 1, or 2). ``None`` for single-model responses, preserving the
|
||||
backwards-compatible wire format for existing API consumers.
|
||||
"""
|
||||
|
||||
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
|
||||
turn_index: int
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -708,6 +708,7 @@ def run_research_agent_calls(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from queue import Queue
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
@@ -743,7 +744,8 @@ if __name__ == "__main__":
|
||||
if user is None:
|
||||
raise ValueError("No users found in database. Please create a user first.")
|
||||
|
||||
emitter = Emitter()
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
state_container = ChatStateContainer()
|
||||
|
||||
tool_dict = construct_tools(
|
||||
@@ -790,4 +792,4 @@ if __name__ == "__main__":
|
||||
print(result.intermediate_report)
|
||||
print("=" * 80)
|
||||
print(f"Citations: {result.citation_mapping}")
|
||||
print(f"Total packets emitted: {emitter.bus.qsize()}")
|
||||
print(f"Total packets emitted: {bus.qsize()}")
|
||||
|
||||
@@ -103,6 +103,11 @@ _EXPECTED_CONFLUENCE_GROUPS = [
|
||||
user_emails={"oauth@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="no yuhong allowed",
|
||||
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,253 +0,0 @@
|
||||
"""Unit tests for the Emitter class.
|
||||
|
||||
Covers both modes (standalone and streaming) without any real database,
|
||||
LLM, or queue infrastructure beyond the stdlib Queue.
|
||||
"""
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _placement(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
|
||||
def _packet(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Packet:
|
||||
"""Build a minimal valid packet with an OverallStop payload."""
|
||||
return Packet(
|
||||
placement=_placement(turn_index, tab_index, sub_turn_index),
|
||||
obj=OverallStop(stop_reason="test"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Standalone mode (no merged_queue)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterStandaloneMode:
|
||||
def test_emitted_packet_arrives_on_bus(self) -> None:
|
||||
emitter = Emitter()
|
||||
pkt = _packet()
|
||||
emitter.emit(pkt)
|
||||
assert emitter.bus.get_nowait() is pkt
|
||||
|
||||
def test_bus_is_empty_before_emit(self) -> None:
|
||||
emitter = Emitter()
|
||||
assert emitter.bus.empty()
|
||||
|
||||
def test_multiple_packets_delivered_fifo(self) -> None:
|
||||
emitter = Emitter()
|
||||
p1 = _packet(turn_index=0)
|
||||
p2 = _packet(turn_index=1)
|
||||
emitter.emit(p1)
|
||||
emitter.emit(p2)
|
||||
assert emitter.bus.get_nowait() is p1
|
||||
assert emitter.bus.get_nowait() is p2
|
||||
|
||||
def test_packet_not_modified(self) -> None:
|
||||
"""Standalone mode must not wrap or mutate the packet."""
|
||||
emitter = Emitter()
|
||||
pkt = _packet(turn_index=7, tab_index=3)
|
||||
emitter.emit(pkt)
|
||||
retrieved = emitter.bus.get_nowait()
|
||||
assert retrieved.placement.turn_index == 7
|
||||
assert retrieved.placement.tab_index == 3
|
||||
|
||||
def test_get_default_emitter_is_standalone(self) -> None:
|
||||
emitter = get_default_emitter()
|
||||
pkt = _packet()
|
||||
emitter.emit(pkt)
|
||||
# Packet lands on the bus, not a shared queue
|
||||
assert emitter.bus.get_nowait() is pkt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming mode (merged_queue provided)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterStreamingMode:
|
||||
# --- Queue routing ---
|
||||
|
||||
def test_packet_goes_to_merged_queue_not_bus(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
assert not mq.empty()
|
||||
assert emitter.bus.empty()
|
||||
|
||||
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=1, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
item = mq.get_nowait()
|
||||
assert isinstance(item, tuple)
|
||||
assert len(item) == 2
|
||||
|
||||
# --- model_index tagging ---
|
||||
|
||||
def test_model_idx_none_preserves_model_index_none(self) -> None:
|
||||
"""N=1 backwards-compat: model_index must stay None in the packet."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=None, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index is None
|
||||
|
||||
def test_model_idx_zero_tags_packet(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 0
|
||||
|
||||
def test_model_idx_one_tags_packet(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=1, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 1
|
||||
|
||||
def test_model_idx_two_tags_packet(self) -> None:
|
||||
"""Boundary: third model in a 3-model run."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=2, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 2
|
||||
|
||||
# --- Queue key ---
|
||||
|
||||
def test_key_equals_model_idx_when_set(self) -> None:
|
||||
"""Drain loop uses the key to route packets; it must match model_idx."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=2, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 2
|
||||
|
||||
def test_key_is_zero_when_model_idx_none(self) -> None:
|
||||
"""N=1: key defaults to 0 (single slot in the drain loop)."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=None, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 0
|
||||
|
||||
# --- Placement field preservation ---
|
||||
|
||||
def test_turn_index_is_preserved(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet(turn_index=5))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.turn_index == 5
|
||||
|
||||
def test_tab_index_is_preserved(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet(tab_index=3))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.tab_index == 3
|
||||
|
||||
def test_sub_turn_index_is_preserved(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet(sub_turn_index=2))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index == 2
|
||||
|
||||
def test_sub_turn_index_none_is_preserved(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet(sub_turn_index=None))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index is None
|
||||
|
||||
def test_packet_obj_is_not_modified(self) -> None:
|
||||
"""The payload object must survive tagging untouched."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
original_obj = OverallStop(stop_reason="sentinel")
|
||||
pkt = Packet(placement=_placement(), obj=original_obj)
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.obj is original_obj
|
||||
|
||||
def test_different_obj_types_are_handled(self) -> None:
|
||||
"""Any valid PacketObj type passes through correctly."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
pkt = Packet(placement=_placement(), obj=ReasoningStart())
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert isinstance(tagged.obj, ReasoningStart)
|
||||
|
||||
# --- bus is always created ---
|
||||
|
||||
def test_bus_exists_in_streaming_mode(self) -> None:
|
||||
"""bus must always be present for backwards-compat with existing callers."""
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
assert hasattr(emitter, "bus")
|
||||
assert isinstance(emitter.bus, queue.Queue)
|
||||
|
||||
def test_bus_stays_empty_in_streaming_mode(self) -> None:
|
||||
import queue
|
||||
|
||||
mq: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(model_idx=0, merged_queue=mq)
|
||||
emitter.emit(_packet())
|
||||
assert emitter.bus.empty()
|
||||
@@ -1,636 +0,0 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in handle_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _start_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> None:
|
||||
"""Advance the generator one step to trigger early validation."""
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = handle_multi_model_stream(req, user, db, overrides)
|
||||
# Calling next() executes until the first yield OR raises.
|
||||
# Validation errors are raised before any yield.
|
||||
next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_raises(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
|
||||
def test_four_overrides_raises(self) -> None:
|
||||
"""4 overrides exceeds maximum — must raise."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
|
||||
def test_zero_overrides_raises(self) -> None:
|
||||
"""Empty override list raises."""
|
||||
req = _make_request()
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [])
|
||||
|
||||
def test_deep_research_raises(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model."""
|
||||
req = _make_request(deep_research=True)
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override fails, 2 passes — ensures fence-post is correct."""
|
||||
req = _make_request()
|
||||
# 1 override must fail
|
||||
with pytest.raises(ValueError, match="2-3"):
|
||||
_start_stream(req, [_make_override()])
|
||||
# 2 overrides must NOT raise ValueError (may raise later due to missing session, that's OK)
|
||||
try:
|
||||
_start_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
except ValueError as exc:
|
||||
pytest.fail(f"2 overrides should pass validation, got ValueError: {exc}")
|
||||
except Exception:
|
||||
pass # Any other error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
assert user_msg.latest_child_message_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_models — drain loop behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_setup(n_models: int = 1) -> MagicMock:
|
||||
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
|
||||
setup = MagicMock()
|
||||
setup.llms = [MagicMock() for _ in range(n_models)]
|
||||
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
|
||||
setup.reserved_token_count = 100
|
||||
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
|
||||
# constructors inside _run_model — must be typed correctly for Pydantic.
|
||||
setup.new_msg_req.deep_research = False
|
||||
setup.new_msg_req.internal_search_filters = None
|
||||
setup.new_msg_req.allowed_tool_ids = None
|
||||
setup.new_msg_req.include_citations = True
|
||||
setup.search_params.project_id_filter = None
|
||||
setup.search_params.persona_id_filter = None
|
||||
setup.bypass_acl = False
|
||||
setup.slack_context = None
|
||||
setup.available_files.user_file_ids = []
|
||||
setup.available_files.chat_file_ids = []
|
||||
setup.forced_tool_id = None
|
||||
setup.simple_chat_history = []
|
||||
setup.chat_session.id = uuid4()
|
||||
setup.user_message.id = None
|
||||
setup.custom_tool_additional_headers = None
|
||||
setup.mcp_headers = None
|
||||
return setup
|
||||
|
||||
|
||||
_RUN_MODELS_PATCHES = [
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch("onyx.chat.process_message.get_llm_token_counter", return_value=lambda _: 0),
|
||||
]
|
||||
|
||||
|
||||
def _run_models_collect(setup: MagicMock) -> list:
|
||||
"""Drive _run_models to completion and return all yielded items."""
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
return list(_run_models(setup, MagicMock(), MagicMock()))
|
||||
|
||||
|
||||
class TestRunModels:
|
||||
"""Tests for the _run_models worker-thread drain loop.
|
||||
|
||||
All external dependencies (LLM, DB, tools) are patched out. Worker threads
|
||||
still run but return immediately since run_llm_loop is mocked.
|
||||
"""
|
||||
|
||||
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
|
||||
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
|
||||
|
||||
def emit_stop(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(stop_reason="complete"),
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert len(stops) == 1
|
||||
stop_obj = stops[0].obj
|
||||
assert isinstance(stop_obj, OverallStop)
|
||||
assert stop_obj.stop_reason == "complete"
|
||||
|
||||
def test_n1_emitted_packet_has_model_index_none(self) -> None:
|
||||
"""Single-model path: model_index stays None for wire backwards-compat."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index is None
|
||||
|
||||
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
|
||||
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
# _model_idx is set by _run_model based on position in setup.llms
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=2))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 2
|
||||
indices = {p.placement.model_index for p in reasoning}
|
||||
assert indices == {0, 1}
|
||||
|
||||
def test_model_error_yields_streaming_error(self) -> None:
|
||||
"""An exception inside a worker thread is surfaced as a StreamingError."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("intentional test failure")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].error_code == "MODEL_ERROR"
|
||||
assert "intentional test failure" in errors[0].error
|
||||
|
||||
def test_one_model_error_does_not_stop_other_models(self) -> None:
|
||||
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
|
||||
|
||||
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
|
||||
emitter = kwargs["emitter"]
|
||||
# _model_idx is None for N=1, int for N>1
|
||||
if emitter._model_idx == 0:
|
||||
raise RuntimeError("model 0 failed")
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=fail_model_0_succeed_model_1,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=2))
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 1
|
||||
|
||||
def test_cancellation_yields_user_cancelled_stop(self) -> None:
|
||||
"""If check_is_connected returns False, drain loop emits user_cancelled."""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert any(
|
||||
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
|
||||
for s in stops
|
||||
)
|
||||
|
||||
def test_completion_handle_called_on_disconnect(self) -> None:
|
||||
"""llm_loop_completion_handle must still be called even when user disconnects.
|
||||
|
||||
Regression test for the disconnect-cleanup bug: the old
|
||||
run_chat_loop_with_state_containers always called completion_callback in
|
||||
its finally block (even on disconnect) so the DB message was updated from
|
||||
the TERMINATED placeholder to a partial answer. The new _run_models must
|
||||
replicate this — otherwise the integration test
|
||||
test_send_message_disconnect_and_cleanup fails because the message stays
|
||||
as "Response was terminated prior to completion, try regenerating."
|
||||
"""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Must be called once per model, not zero times
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_called_for_each_successful_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be called once per model that succeeded."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_not_called_for_failed_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be skipped for a model that raised."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("fail")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
mock_handle.assert_not_called()
|
||||
|
||||
def test_http_disconnect_completion_via_generator_exit(self) -> None:
|
||||
"""GeneratorExit from HTTP disconnect triggers wait+completion in finally.
|
||||
|
||||
When the HTTP client closes the connection, Starlette throws GeneratorExit
|
||||
into the stream generator, which propagates into _run_models. The finally
|
||||
block must call executor.shutdown(wait=True) to wait for LLM threads to
|
||||
finish, then persist their results via llm_loop_completion_handle.
|
||||
|
||||
This is the primary regression for test_send_message_disconnect_and_cleanup:
|
||||
the integration test disconnects mid-stream and expects the DB message to be
|
||||
updated from the TERMINATED placeholder to the real response.
|
||||
"""
|
||||
import threading
|
||||
|
||||
thread_completed = threading.Event()
|
||||
|
||||
def emit_then_complete(**kwargs: Any) -> None:
|
||||
"""Emit one packet (to give generator a yield point), then finish."""
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
# Small sleep so executor.shutdown(wait=True) in finally actually waits.
|
||||
time.sleep(0.05)
|
||||
thread_completed.set()
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_then_complete,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
gen = _run_models(setup, MagicMock(), MagicMock())
|
||||
# Advance to the first yielded packet — generator suspends at `yield item`.
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
# Simulate Starlette closing the stream on HTTP client disconnect.
|
||||
# GeneratorExit is thrown at the `yield item` suspension point.
|
||||
gen.close()
|
||||
|
||||
# Finally block must have waited for the thread and saved completion.
|
||||
assert (
|
||||
thread_completed.is_set()
|
||||
), "LLM thread must complete before gen.close() returns"
|
||||
assert (
|
||||
mock_handle.call_count == 1
|
||||
), "completion handle must be called for the successful model"
|
||||
|
||||
def test_external_state_container_used_for_model_zero(self) -> None:
|
||||
"""When provided, external_state_container is used as state_containers[0]."""
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
external = ChatStateContainer()
|
||||
setup = _make_setup(n_models=1)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
list(
|
||||
_run_models(
|
||||
setup, MagicMock(), MagicMock(), external_state_container=external
|
||||
)
|
||||
)
|
||||
|
||||
# The state_container kwarg passed to run_llm_loop must be the external one
|
||||
call_kwargs = mock_llm.call_args.kwargs
|
||||
assert call_kwargs["state_container"] is external
|
||||
@@ -0,0 +1,381 @@
|
||||
"""Tests for Canvas connector — client (PR1)."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FAKE_BASE_URL = "https://myschool.instructure.com"
|
||||
FAKE_TOKEN = "fake-canvas-token"
|
||||
|
||||
|
||||
def _mock_response(
|
||||
status_code: int = 200,
|
||||
json_data: Any = None,
|
||||
link_header: str = "",
|
||||
) -> MagicMock:
|
||||
"""Create a mock HTTP response with status, json, and Link header."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.reason = "OK" if status_code < 300 else "Error"
|
||||
resp.json.return_value = json_data if json_data is not None else []
|
||||
resp.headers = {"Link": link_header}
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient.__init__ tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCanvasApiClientInit:
|
||||
def test_success(self) -> None:
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
|
||||
expected_host = "myschool.instructure.com"
|
||||
|
||||
assert client.base_url == expected_base_url
|
||||
assert client._expected_host == expected_host
|
||||
|
||||
def test_normalizes_trailing_slash(self) -> None:
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=f"{FAKE_BASE_URL}/",
|
||||
)
|
||||
|
||||
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
|
||||
|
||||
assert client.base_url == expected_base_url
|
||||
|
||||
def test_normalizes_existing_api_v1(self) -> None:
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=f"{FAKE_BASE_URL}/api/v1",
|
||||
)
|
||||
|
||||
expected_base_url = f"{FAKE_BASE_URL}/api/v1"
|
||||
|
||||
assert client.base_url == expected_base_url
|
||||
|
||||
def test_rejects_non_https_scheme(self) -> None:
|
||||
with pytest.raises(ValueError, match="must use https"):
|
||||
CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url="ftp://myschool.instructure.com",
|
||||
)
|
||||
|
||||
def test_rejects_http(self) -> None:
|
||||
with pytest.raises(ValueError, match="must use https"):
|
||||
CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url="http://myschool.instructure.com",
|
||||
)
|
||||
|
||||
def test_rejects_missing_host(self) -> None:
|
||||
with pytest.raises(ValueError, match="must include a valid host"):
|
||||
CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url="https://",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient._build_url tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildUrl:
|
||||
def setup_method(self) -> None:
|
||||
self.client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
def test_appends_endpoint(self) -> None:
|
||||
result = self.client._build_url("courses")
|
||||
expected = f"{FAKE_BASE_URL}/api/v1/courses"
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_strips_leading_slash_from_endpoint(self) -> None:
|
||||
result = self.client._build_url("/courses")
|
||||
expected = f"{FAKE_BASE_URL}/api/v1/courses"
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient._build_headers tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildHeaders:
|
||||
def setup_method(self) -> None:
|
||||
self.client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
def test_returns_bearer_auth(self) -> None:
|
||||
result = self.client._build_headers()
|
||||
expected = {"Authorization": f"Bearer {FAKE_TOKEN}"}
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient.get tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGet:
|
||||
def setup_method(self) -> None:
|
||||
self.client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_success_returns_json_and_next_url(self, mock_requests: MagicMock) -> None:
|
||||
next_link = f"<{FAKE_BASE_URL}/api/v1/courses?page=2>; " 'rel="next"'
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[{"id": 1}], link_header=next_link
|
||||
)
|
||||
|
||||
data, next_url = self.client.get("courses")
|
||||
|
||||
expected_data = [{"id": 1}]
|
||||
expected_next = f"{FAKE_BASE_URL}/api/v1/courses?page=2"
|
||||
|
||||
assert data == expected_data
|
||||
assert next_url == expected_next
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_success_no_next_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[{"id": 1}])
|
||||
|
||||
data, next_url = self.client.get("courses")
|
||||
|
||||
assert data == [{"id": 1}]
|
||||
assert next_url is None
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_raises_on_error_status(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(403, {})
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_raises_on_404(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(404, {})
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_raises_on_429(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(429, {})
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_skips_params_when_using_full_url(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
full = f"{FAKE_BASE_URL}/api/v1/courses?page=2"
|
||||
|
||||
self.client.get(params={"per_page": "100"}, full_url=full)
|
||||
|
||||
_, kwargs = mock_requests.get.call_args
|
||||
assert kwargs["params"] is None
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_extracts_message_from_error_dict(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Shape 1: {"error": {"message": "Not authorized"}}"""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
403, {"error": {"message": "Not authorized"}}
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Not authorized"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_extracts_message_from_error_string(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Shape 2: {"error": "Invalid access token"}"""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
401, {"error": "Invalid access token"}
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Invalid access token"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_extracts_message_from_errors_list(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Shape 3: {"errors": [{"message": "Invalid query"}]}"""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
400, {"errors": [{"message": "Invalid query"}]}
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Invalid query"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_dict_takes_priority_over_errors_list(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""When both error shapes are present, error dict wins."""
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
403, {"error": "Specific error", "errors": [{"message": "Generic"}]}
|
||||
)
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Specific error"
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_error_falls_back_to_reason_when_no_json_message(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Empty error body falls back to response.reason."""
|
||||
mock_requests.get.return_value = _mock_response(500, {})
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Error" # from _mock_response's reason for >= 300
|
||||
|
||||
assert result == expected
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_invalid_json_on_success_raises(self, mock_requests: MagicMock) -> None:
|
||||
"""Invalid JSON on a 2xx response raises OnyxError."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.side_effect = ValueError("No JSON")
|
||||
resp.headers = {"Link": ""}
|
||||
mock_requests.get.return_value = resp
|
||||
|
||||
with pytest.raises(OnyxError, match="Invalid JSON"):
|
||||
self.client.get("courses")
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_invalid_json_on_error_falls_back_to_reason(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
"""Invalid JSON on a 4xx response falls back to response.reason."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = 500
|
||||
resp.reason = "Internal Server Error"
|
||||
resp.json.side_effect = ValueError("No JSON")
|
||||
resp.headers = {"Link": ""}
|
||||
mock_requests.get.return_value = resp
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self.client.get("courses")
|
||||
|
||||
result = exc_info.value.detail
|
||||
expected = "Internal Server Error"
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient._parse_next_link tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseNextLink:
|
||||
def setup_method(self) -> None:
|
||||
self.client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url="https://canvas.example.com",
|
||||
)
|
||||
|
||||
def test_found(self) -> None:
|
||||
header = '<https://canvas.example.com/api/v1/courses?page=2>; rel="next"'
|
||||
|
||||
result = self.client._parse_next_link(header)
|
||||
expected = "https://canvas.example.com/api/v1/courses?page=2"
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_not_found(self) -> None:
|
||||
header = '<https://canvas.example.com/api/v1/courses?page=1>; rel="current"'
|
||||
|
||||
result = self.client._parse_next_link(header)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_empty(self) -> None:
|
||||
result = self.client._parse_next_link("")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_multiple_rels(self) -> None:
|
||||
header = (
|
||||
'<https://canvas.example.com/api/v1/courses?page=1>; rel="current", '
|
||||
'<https://canvas.example.com/api/v1/courses?page=2>; rel="next"'
|
||||
)
|
||||
|
||||
result = self.client._parse_next_link(header)
|
||||
expected = "https://canvas.example.com/api/v1/courses?page=2"
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_rejects_host_mismatch(self) -> None:
|
||||
header = '<https://evil.example.com/api/v1/courses?page=2>; rel="next"'
|
||||
|
||||
with pytest.raises(OnyxError, match="unexpected host"):
|
||||
self.client._parse_next_link(header)
|
||||
|
||||
def test_rejects_non_https_link(self) -> None:
|
||||
header = '<http://canvas.example.com/api/v1/courses?page=2>; rel="next"'
|
||||
|
||||
with pytest.raises(OnyxError, match="must use https"):
|
||||
self.client._parse_next_link(header)
|
||||
147
backend/tests/unit/onyx/connectors/jira/test_jira_bulk_fetch.py
Normal file
147
backend/tests/unit/onyx/connectors/jira/test_jira_bulk_fetch.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.connectors.jira.connector import bulk_fetch_issues
|
||||
|
||||
|
||||
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": issue_id,
|
||||
"key": f"TEST-{issue_id}",
|
||||
"fields": {"summary": f"Issue {issue_id}"},
|
||||
}
|
||||
|
||||
|
||||
def _mock_jira_client() -> MagicMock:
|
||||
mock = MagicMock(spec=JIRA)
|
||||
mock._options = {"server": "https://jira.example.com"}
|
||||
mock._session = MagicMock()
|
||||
mock._get_url = MagicMock(
|
||||
return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
def test_bulk_fetch_success() -> None:
|
||||
"""Happy path: all issues fetched in one request."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1"), _make_raw_issue("2"), _make_raw_issue("3")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3"])
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(r, Issue) for r in result)
|
||||
client._session.post.assert_called_once()
|
||||
|
||||
|
||||
def test_bulk_fetch_splits_on_json_error() -> None:
|
||||
"""When the full batch fails with JSONDecodeError, sub-batches succeed."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
call_count = 0
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if len(ids) > 2:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 2294125
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3", "4"])
|
||||
assert len(result) == 4
|
||||
returned_ids = {r.raw["id"] for r in result}
|
||||
assert returned_ids == {"1", "2", "3", "4"}
|
||||
assert call_count > 1
|
||||
|
||||
|
||||
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
|
||||
"""A single issue that always fails JSON decode raises after splitting."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if "bad" in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 100
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "bad", "2"])
|
||||
|
||||
|
||||
def test_bulk_fetch_non_json_error_propagates() -> None:
|
||||
"""Non-JSONDecodeError exceptions still propagate."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = ValueError("something else broke")
|
||||
client._session.post.return_value = resp
|
||||
|
||||
try:
|
||||
bulk_fetch_issues(client, ["1"])
|
||||
assert False, "Expected ValueError to propagate"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def test_bulk_fetch_with_fields() -> None:
|
||||
"""Fields parameter is forwarded correctly."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
bulk_fetch_issues(client, ["1"], fields="summary,description")
|
||||
|
||||
call_payload = client._session.post.call_args[1]["json"]
|
||||
assert call_payload["fields"] == ["summary", "description"]
|
||||
|
||||
|
||||
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
|
||||
"""With a 6-issue batch where one is bad, recursion isolates it and raises."""
|
||||
client = _mock_jira_client()
|
||||
bad_id = "BAD"
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if bad_id in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"truncated", "doc", 999
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
|
||||
@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
|
||||
class TestDeleteVoiceProvider:
|
||||
"""Tests for delete_voice_provider."""
|
||||
|
||||
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
delete_voice_provider(mock_db_session, 1)
|
||||
|
||||
assert provider.deleted is True
|
||||
mock_db_session.delete.assert_called_once_with(provider)
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_provider_not_found(
|
||||
|
||||
@@ -1462,3 +1462,69 @@ def test_no_tool_choice_sent_when_no_tools(default_multi_llm: LitellmLLM) -> Non
|
||||
assert (
|
||||
"tool_choice" not in kwargs
|
||||
), "tool_choice must not be sent to providers when no tools are provided"
|
||||
|
||||
|
||||
def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
api_base="https://bifrost.example.com/",
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BIFROST,
|
||||
model_name="anthropic/claude-sonnet-4-6",
|
||||
max_input_tokens=32000,
|
||||
)
|
||||
|
||||
assert llm._custom_llm_provider == "openai"
|
||||
assert llm._api_base == "https://bifrost.example.com/v1"
|
||||
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
|
||||
|
||||
|
||||
def test_bifrost_claude_includes_allowed_openai_params() -> None:
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
api_base="https://bifrost.example.com",
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BIFROST,
|
||||
model_name="anthropic/claude-sonnet-4-6",
|
||||
max_input_tokens=32000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Use a tool if needed")]
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup",
|
||||
"description": "Look up data",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Done"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="anthropic/claude-sonnet-4-6",
|
||||
),
|
||||
]
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
llm.invoke(messages, tools=tools)
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert kwargs["model"] == "anthropic/claude-sonnet-4-6"
|
||||
assert kwargs["base_url"] == "https://bifrost.example.com/v1"
|
||||
assert kwargs["custom_llm_provider"] == "openai"
|
||||
assert kwargs["allowed_openai_params"] == ["tool_choice"]
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Covers:
|
||||
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
|
||||
- _validate_endpoint: httpx exception → HookValidateStatus mapping
|
||||
ConnectTimeout → cannot_connect (TCP handshake never completed)
|
||||
ConnectTimeout → timeout (any timeout directs user to increase timeout_seconds)
|
||||
ConnectError → cannot_connect (DNS / TLS failure)
|
||||
ReadTimeout et al. → timeout (TCP connected, server slow)
|
||||
Any other exc → cannot_connect
|
||||
@@ -61,7 +61,7 @@ class TestCheckSsrfSafety:
|
||||
def test_non_https_scheme_rejected(self, url: str) -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self._call(url)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert "https" in (exc_info.value.detail or "").lower()
|
||||
|
||||
# --- private IP blocklist ---
|
||||
@@ -87,7 +87,7 @@ class TestCheckSsrfSafety:
|
||||
):
|
||||
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
|
||||
self._call("https://internal.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert ip in (exc_info.value.detail or "")
|
||||
|
||||
def test_public_ip_is_allowed(self) -> None:
|
||||
@@ -106,7 +106,7 @@ class TestCheckSsrfSafety:
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
self._call("https://no-such-host.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -158,13 +158,11 @@ class TestValidateEndpoint:
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_timeout_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectTimeout("timed out")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
assert self._call().status == HookValidateStatus.timeout
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -12,6 +12,8 @@ import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.llm.models import BifrostFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BifrostModelsRequest
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
@@ -850,13 +852,15 @@ class TestGetLitellmAvailableModels:
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_get.side_effect = Exception("Connection refused")
|
||||
mock_get.side_effect = httpx.ConnectError(
|
||||
"Connection refused", request=MagicMock()
|
||||
)
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM models"):
|
||||
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM proxy models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_401_raises_authentication_error(self) -> None:
|
||||
@@ -898,3 +902,113 @@ class TestGetLitellmAvailableModels:
|
||||
)
|
||||
with pytest.raises(OnyxError, match="endpoint not found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
|
||||
class TestGetBifrostAvailableModels:
|
||||
"""Tests for the Bifrost model fetch endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bifrost_response(self) -> dict:
|
||||
"""Mock response from Bifrost /v1/models endpoint."""
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"id": "anthropic/claude-3-5-sonnet",
|
||||
"name": "Claude 3.5 Sonnet",
|
||||
"context_length": 200000,
|
||||
},
|
||||
{
|
||||
"id": "openai/gpt-4o",
|
||||
"name": "GPT-4o",
|
||||
"context_length": 128000,
|
||||
},
|
||||
{
|
||||
"id": "deepseek/deepseek-r1",
|
||||
"name": "DeepSeek R1",
|
||||
"context_length": 64000,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
def test_returns_model_list(self, mock_bifrost_response: dict) -> None:
|
||||
"""Test that endpoint returns properly formatted non-embedding models."""
|
||||
from onyx.server.manage.llm.api import get_bifrost_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_bifrost_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = BifrostModelsRequest(api_base="https://bifrost.example.com")
|
||||
results = get_bifrost_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(r, BifrostFinalModelResponse) for r in results)
|
||||
assert [r.name for r in results] == sorted(
|
||||
[r.name for r in results], key=str.lower
|
||||
)
|
||||
|
||||
def test_infers_vision_support(self, mock_bifrost_response: dict) -> None:
|
||||
"""Test that vision support is inferred from provider/model IDs."""
|
||||
from onyx.server.manage.llm.api import get_bifrost_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_bifrost_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = BifrostModelsRequest(api_base="https://bifrost.example.com")
|
||||
results = get_bifrost_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
claude = next(r for r in results if r.name == "anthropic/claude-3-5-sonnet")
|
||||
gpt4o = next(r for r in results if r.name == "openai/gpt-4o")
|
||||
deepseek = next(r for r in results if r.name == "deepseek/deepseek-r1")
|
||||
|
||||
assert claude.supports_image_input is True
|
||||
assert gpt4o.supports_image_input is True
|
||||
assert deepseek.supports_image_input is False
|
||||
|
||||
def test_existing_v1_suffix_is_not_duplicated(self) -> None:
|
||||
"""Test that an existing /v1 suffix still hits a single /v1/models endpoint."""
|
||||
from onyx.server.manage.llm.api import get_bifrost_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response = {"data": [{"id": "openai/gpt-4o", "name": "GPT-4o"}]}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = BifrostModelsRequest(api_base="https://bifrost.example.com/v1")
|
||||
get_bifrost_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
called_url = mock_get.call_args[0][0]
|
||||
assert called_url == "https://bifrost.example.com/v1/models"
|
||||
|
||||
def test_request_failure_is_logged_and_wrapped(self) -> None:
|
||||
"""Test that request-layer failures are logged before raising OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_bifrost_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with (
|
||||
patch("onyx.server.manage.llm.api.httpx.get") as mock_get,
|
||||
patch("onyx.server.manage.llm.api.logger.warning") as mock_warning,
|
||||
):
|
||||
mock_get.side_effect = httpx.ConnectError(
|
||||
"Connection refused", request=MagicMock()
|
||||
)
|
||||
|
||||
request = BifrostModelsRequest(api_base="https://bifrost.example.com")
|
||||
with pytest.raises(OnyxError, match="Failed to fetch Bifrost models"):
|
||||
get_bifrost_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
mock_warning.assert_called_once()
|
||||
|
||||
@@ -176,6 +176,14 @@ class TestInferVisionSupport:
|
||||
"""Test Nova Pro has vision."""
|
||||
assert infer_vision_support("amazon.nova-pro-v1") is True
|
||||
|
||||
def test_bifrost_claude_has_vision(self) -> None:
|
||||
"""Test Bifrost Claude models are recognized as vision-capable."""
|
||||
assert infer_vision_support("anthropic/claude-3-5-sonnet") is True
|
||||
|
||||
def test_bifrost_gpt4o_has_vision(self) -> None:
|
||||
"""Test Bifrost GPT-4o models are recognized as vision-capable."""
|
||||
assert infer_vision_support("openai/gpt-4o") is True
|
||||
|
||||
def test_mistral_no_vision(self) -> None:
|
||||
"""Test Mistral doesn't have vision (not in known list)."""
|
||||
assert infer_vision_support("mistral.mistral-large") is False
|
||||
|
||||
170
backend/tests/unit/server/metrics/test_worker_health.py
Normal file
170
backend/tests/unit/server/metrics/test_worker_health.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Tests for WorkerHeartbeatMonitor and WorkerHealthCollector."""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHeartbeatMonitor
|
||||
|
||||
|
||||
class TestWorkerHeartbeatMonitor:
|
||||
def test_heartbeat_registers_worker(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert "primary@host1" in status
|
||||
assert status["primary@host1"] is True
|
||||
|
||||
def test_multiple_workers(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
monitor._on_heartbeat({"hostname": "docfetching@host1"})
|
||||
monitor._on_heartbeat({"hostname": "monitoring@host1"})
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert len(status) == 3
|
||||
assert all(alive for alive in status.values())
|
||||
|
||||
def test_offline_removes_worker(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
monitor._on_offline({"hostname": "primary@host1"})
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert "primary@host1" not in status
|
||||
|
||||
def test_stale_heartbeat_marks_worker_down(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
with monitor._lock:
|
||||
monitor._worker_last_seen["primary@host1"] = (
|
||||
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS - 10
|
||||
)
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert status["primary@host1"] is False
|
||||
|
||||
def test_very_stale_worker_is_pruned(self) -> None:
|
||||
"""Workers dead for 2x the timeout are pruned from the dict."""
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
with monitor._lock:
|
||||
monitor._worker_last_seen["gone@host1"] = (
|
||||
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS * 2 - 10
|
||||
)
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert "gone@host1" not in status
|
||||
assert monitor.get_worker_status() == {}
|
||||
|
||||
def test_heartbeat_refreshes_stale_worker(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
with monitor._lock:
|
||||
monitor._worker_last_seen["primary@host1"] = (
|
||||
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS - 10
|
||||
)
|
||||
assert monitor.get_worker_status()["primary@host1"] is False
|
||||
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
assert monitor.get_worker_status()["primary@host1"] is True
|
||||
|
||||
def test_ignores_empty_hostname(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({})
|
||||
monitor._on_heartbeat({"hostname": ""})
|
||||
monitor._on_offline({})
|
||||
|
||||
assert monitor.get_worker_status() == {}
|
||||
|
||||
def test_returns_full_hostname_as_key(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "docprocessing@my-long-host.local"})
|
||||
|
||||
status = monitor.get_worker_status()
|
||||
assert "docprocessing@my-long-host.local" in status
|
||||
|
||||
def test_start_is_idempotent(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
# Mock the thread so we don't actually start one
|
||||
mock_thread = MagicMock()
|
||||
mock_thread.is_alive.return_value = True
|
||||
monitor._thread = mock_thread
|
||||
monitor._running = True
|
||||
|
||||
# Second start should be a no-op
|
||||
monitor.start()
|
||||
# Thread constructor should not have been called again
|
||||
assert monitor._thread is mock_thread
|
||||
|
||||
def test_thread_safety(self) -> None:
|
||||
"""get_worker_status should not raise even if heartbeats arrive concurrently."""
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
status = monitor.get_worker_status()
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
status2 = monitor.get_worker_status()
|
||||
assert status == status2
|
||||
|
||||
|
||||
class TestWorkerHealthCollector:
|
||||
def test_returns_empty_when_no_monitor(self) -> None:
|
||||
collector = WorkerHealthCollector(cache_ttl=0)
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_collects_active_workers(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
monitor._on_heartbeat({"hostname": "docfetching@host1"})
|
||||
monitor._on_heartbeat({"hostname": "monitoring@host1"})
|
||||
|
||||
collector = WorkerHealthCollector(cache_ttl=0)
|
||||
collector.set_monitor(monitor)
|
||||
|
||||
families = collector.collect()
|
||||
assert len(families) == 2
|
||||
|
||||
active = families[0]
|
||||
assert active.name == "onyx_celery_active_worker_count"
|
||||
assert active.samples[0].value == 3
|
||||
|
||||
up = families[1]
|
||||
assert up.name == "onyx_celery_worker_up"
|
||||
assert len(up.samples) == 3
|
||||
# Labels use short names (before @)
|
||||
labels = {s.labels["worker"] for s in up.samples}
|
||||
assert labels == {"primary", "docfetching", "monitoring"}
|
||||
for sample in up.samples:
|
||||
assert sample.value == 1
|
||||
|
||||
def test_reports_dead_worker(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
with monitor._lock:
|
||||
monitor._worker_last_seen["monitoring@host1"] = (
|
||||
time.monotonic() - monitor._HEARTBEAT_TIMEOUT_SECONDS - 10
|
||||
)
|
||||
|
||||
collector = WorkerHealthCollector(cache_ttl=0)
|
||||
collector.set_monitor(monitor)
|
||||
|
||||
families = collector.collect()
|
||||
active = families[0]
|
||||
assert active.samples[0].value == 1
|
||||
|
||||
up = families[1]
|
||||
samples_by_name = {s.labels["worker"]: s.value for s in up.samples}
|
||||
assert samples_by_name["primary"] == 1
|
||||
assert samples_by_name["monitoring"] == 0
|
||||
|
||||
def test_empty_monitor_returns_zero(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
|
||||
collector = WorkerHealthCollector(cache_ttl=0)
|
||||
collector.set_monitor(monitor)
|
||||
|
||||
families = collector.collect()
|
||||
assert len(families) == 2
|
||||
active = families[0]
|
||||
assert active.samples[0].value == 0
|
||||
up = families[1]
|
||||
assert up.name == "onyx_celery_worker_up"
|
||||
assert len(up.samples) == 0
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for memory tool streaming packet emissions."""
|
||||
|
||||
from queue import Queue
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -18,7 +19,8 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
|
||||
@pytest.fixture
|
||||
def emitter() -> Emitter:
|
||||
return Emitter()
|
||||
bus: Queue = Queue()
|
||||
return Emitter(bus)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
128
greptile.json
128
greptile.json
@@ -1,128 +0,0 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 2,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "greptile.json\n",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"customContext": {
|
||||
"other": [
|
||||
{
|
||||
"scope": [],
|
||||
"content": "Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"content": "Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly."
|
||||
}
|
||||
],
|
||||
"rules": [
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of TODO(name): ... or TODO(1234): ..."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "For frontend changes (changes that touch the /web directory), make sure to enforce all standards described in the web/AGENTS.md file."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Remove temporary debugging code before merging to production, especially tenant-specific debugging logs."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"scope": [],
|
||||
"path": "contributing_guides/best_practices.md",
|
||||
"description": "Best practices for contributing to the codebase"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "CLAUDE.md",
|
||||
"description": "Project instructions and coding standards"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "backend/alembic/README.md",
|
||||
"description": "Migration guidance, including multi-tenant migration behavior"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/helm/charts/onyx/values-lite.yaml",
|
||||
"description": "Lite deployment Helm values and service assumptions"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
|
||||
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,9 +342,9 @@ visible text in the DOM (e.g., `title`, `description`, `label`) should be typed
|
||||
|
||||
```typescript
|
||||
import type { RichStr } from "@opal/types";
|
||||
import { resolveStr } from "@opal/components/text/InlineMarkdown";
|
||||
import { Text } from "@opal/components";
|
||||
|
||||
// ✅ Good — new components accept string | RichStr
|
||||
// ✅ Good — new components accept string | RichStr and render via Text
|
||||
interface InfoCardProps {
|
||||
title: string | RichStr;
|
||||
description?: string | RichStr;
|
||||
@@ -353,9 +353,9 @@ interface InfoCardProps {
|
||||
function InfoCard({ title, description }: InfoCardProps) {
|
||||
return (
|
||||
<div>
|
||||
<Text font="main-ui-action">{resolveStr(title)}</Text>
|
||||
<Text font="main-ui-action">{title}</Text>
|
||||
{description && (
|
||||
<Text font="secondary-body" color="text-03">{resolveStr(description)}</Text>
|
||||
<Text font="secondary-body" color="text-03">{description}</Text>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -4,11 +4,15 @@ import {
|
||||
Interactive,
|
||||
type InteractiveStatelessProps,
|
||||
} from "@opal/core";
|
||||
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
|
||||
import type {
|
||||
ContainerSizeVariants,
|
||||
ExtremaSizeVariants,
|
||||
RichStr,
|
||||
} from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
import { cn } from "@opal/utils";
|
||||
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -18,13 +22,13 @@ import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
|
||||
type ButtonContentProps =
|
||||
| {
|
||||
icon?: IconFunctionComponent;
|
||||
children: string;
|
||||
children: string | RichStr;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
responsiveHideText?: never;
|
||||
}
|
||||
| {
|
||||
icon: IconFunctionComponent;
|
||||
children?: string;
|
||||
children?: string | RichStr;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
responsiveHideText?: boolean;
|
||||
};
|
||||
@@ -69,15 +73,24 @@ function Button({
|
||||
const isLarge = size === "lg";
|
||||
|
||||
const labelEl = children ? (
|
||||
<span
|
||||
className={cn(
|
||||
"whitespace-nowrap",
|
||||
isLarge ? "font-main-ui-body " : "font-secondary-body",
|
||||
responsiveHideText && "hidden md:inline"
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</span>
|
||||
responsiveHideText ? (
|
||||
<span className="hidden md:inline whitespace-nowrap">
|
||||
<Text
|
||||
font={isLarge ? "main-ui-body" : "secondary-body"}
|
||||
color="inherit"
|
||||
>
|
||||
{children}
|
||||
</Text>
|
||||
</span>
|
||||
) : (
|
||||
<Text
|
||||
font={isLarge ? "main-ui-body" : "secondary-body"}
|
||||
color="inherit"
|
||||
nowrap
|
||||
>
|
||||
{children}
|
||||
</Text>
|
||||
)
|
||||
) : null;
|
||||
|
||||
const button = (
|
||||
|
||||
@@ -4,7 +4,8 @@ import {
|
||||
type InteractiveStatefulProps,
|
||||
} from "@opal/core";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
import { iconWrapper } from "@opal/components/buttons/icon-wrapper";
|
||||
@@ -16,12 +17,12 @@ import { Button } from "@opal/components/buttons/button/components";
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface FilterButtonProps
|
||||
extends Omit<InteractiveStatefulProps, "variant" | "state"> {
|
||||
extends Omit<InteractiveStatefulProps, "variant" | "state" | "children"> {
|
||||
/** Left icon — always visible. */
|
||||
icon: IconFunctionComponent;
|
||||
|
||||
/** Label text between icon and trailing indicator. */
|
||||
children: string;
|
||||
children: string | RichStr;
|
||||
|
||||
/** Whether the filter has an active selection. @default false */
|
||||
active?: boolean;
|
||||
@@ -68,9 +69,9 @@ function FilterButton({
|
||||
<Interactive.Container type="button">
|
||||
<div className="interactive-foreground flex flex-row items-center gap-1">
|
||||
{iconWrapper(Icon, "lg", true)}
|
||||
<span className="whitespace-nowrap font-main-ui-action">
|
||||
<Text font="main-ui-action" color="inherit" nowrap>
|
||||
{children}
|
||||
</span>
|
||||
</Text>
|
||||
<div style={{ visibility: active ? "hidden" : "visible" }}>
|
||||
{iconWrapper(ChevronIcon, "lg", true)}
|
||||
</div>
|
||||
|
||||
@@ -4,7 +4,12 @@ import {
|
||||
type InteractiveStatefulProps,
|
||||
type InteractiveStatefulInteraction,
|
||||
} from "@opal/core";
|
||||
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
|
||||
import type {
|
||||
ContainerSizeVariants,
|
||||
ExtremaSizeVariants,
|
||||
RichStr,
|
||||
} from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import type { InteractiveContainerRoundingVariant } from "@opal/core";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
@@ -28,17 +33,17 @@ type OpenButtonContentProps =
|
||||
| {
|
||||
foldable: true;
|
||||
icon: IconFunctionComponent;
|
||||
children: string;
|
||||
children: string | RichStr;
|
||||
}
|
||||
| {
|
||||
foldable?: false;
|
||||
icon?: IconFunctionComponent;
|
||||
children: string;
|
||||
children: string | RichStr;
|
||||
}
|
||||
| {
|
||||
foldable?: false;
|
||||
icon: IconFunctionComponent;
|
||||
children?: string;
|
||||
children?: string | RichStr;
|
||||
};
|
||||
|
||||
type OpenButtonVariant = "select-light" | "select-heavy" | "select-tinted";
|
||||
@@ -101,14 +106,13 @@ function OpenButton({
|
||||
const isLarge = size === "lg";
|
||||
|
||||
const labelEl = children ? (
|
||||
<span
|
||||
className={cn(
|
||||
"whitespace-nowrap",
|
||||
isLarge ? "font-main-ui-body" : "font-secondary-body"
|
||||
)}
|
||||
<Text
|
||||
font={isLarge ? "main-ui-body" : "secondary-body"}
|
||||
color="inherit"
|
||||
nowrap
|
||||
>
|
||||
{children}
|
||||
</span>
|
||||
</Text>
|
||||
) : null;
|
||||
|
||||
const button = (
|
||||
@@ -177,7 +181,7 @@ function OpenButton({
|
||||
side={tooltipSide}
|
||||
sideOffset={4}
|
||||
>
|
||||
{resolvedTooltip}
|
||||
<Text>{resolvedTooltip}</Text>
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
|
||||
@@ -4,7 +4,12 @@ import {
|
||||
useDisabled,
|
||||
type InteractiveStatefulProps,
|
||||
} from "@opal/core";
|
||||
import type { ContainerSizeVariants, ExtremaSizeVariants } from "@opal/types";
|
||||
import type {
|
||||
ContainerSizeVariants,
|
||||
ExtremaSizeVariants,
|
||||
RichStr,
|
||||
} from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
@@ -26,19 +31,19 @@ type SelectButtonContentProps =
|
||||
| {
|
||||
foldable: true;
|
||||
icon: IconFunctionComponent;
|
||||
children: string;
|
||||
children: string | RichStr;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
}
|
||||
| {
|
||||
foldable?: false;
|
||||
icon?: IconFunctionComponent;
|
||||
children: string;
|
||||
children: string | RichStr;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
}
|
||||
| {
|
||||
foldable?: false;
|
||||
icon: IconFunctionComponent;
|
||||
children?: string;
|
||||
children?: string | RichStr;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
};
|
||||
|
||||
@@ -79,13 +84,10 @@ function SelectButton({
|
||||
const isLarge = size === "lg";
|
||||
|
||||
const labelEl = children ? (
|
||||
<span
|
||||
className={cn(
|
||||
"opal-select-button-label",
|
||||
isLarge ? "font-main-ui-body" : "font-secondary-body"
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
<span className="opal-select-button-label">
|
||||
<Text font={isLarge ? "main-ui-body" : "secondary-body"} color="inherit">
|
||||
{children}
|
||||
</Text>
|
||||
</span>
|
||||
) : null;
|
||||
|
||||
@@ -137,7 +139,7 @@ function SelectButton({
|
||||
side={tooltipSide}
|
||||
sideOffset={4}
|
||||
>
|
||||
{resolvedTooltip}
|
||||
<Text>{resolvedTooltip}</Text>
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
|
||||
@@ -4,7 +4,9 @@ import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgArrowRight, SvgChevronLeft, SvgChevronRight } from "@opal/icons";
|
||||
import { containerSizeVariants } from "@opal/shared";
|
||||
import type { WithoutStyles } from "@opal/types";
|
||||
import type { RichStr, WithoutStyles } from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import { toPlainString } from "@opal/components/text/InlineMarkdown";
|
||||
import { cn } from "@opal/utils";
|
||||
import * as PopoverPrimitive from "@radix-ui/react-popover";
|
||||
import {
|
||||
@@ -38,7 +40,7 @@ interface SimplePaginationProps
|
||||
/** Hides the `currentPage/totalPages` summary text between arrows. Default: `false`. */
|
||||
hidePages?: boolean;
|
||||
/** Unit label shown after the summary (e.g. `"pages"`). Always has 4px spacing. */
|
||||
units?: string;
|
||||
units?: string | RichStr;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -63,7 +65,7 @@ interface CountPaginationProps
|
||||
/** Hides the current page number between the arrows. Default: `false`. */
|
||||
hidePages?: boolean;
|
||||
/** Unit label shown after the total count (e.g. `"items"`). Always has 4px spacing. */
|
||||
units?: string;
|
||||
units?: string | RichStr;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -331,7 +333,9 @@ function PaginationSimple({
|
||||
}: SimplePaginationProps) {
|
||||
const handleChange = (page: number) => onChange?.(page);
|
||||
|
||||
const label = `${currentPage}/${totalPages}${units ? ` ${units}` : ""}`;
|
||||
const label = `${currentPage}/${totalPages}${
|
||||
units ? ` ${toPlainString(units)}` : ""
|
||||
}`;
|
||||
|
||||
return (
|
||||
<div {...props} className="flex items-center">
|
||||
@@ -385,7 +389,16 @@ function PaginationCount({
|
||||
{rangeStart}~{rangeEnd}
|
||||
<span className={textClasses(size, "muted")}>of</span>
|
||||
{totalItems}
|
||||
{units && <span className="ml-1">{units}</span>}
|
||||
{units && (
|
||||
<span className="ml-1">
|
||||
<Text
|
||||
color="inherit"
|
||||
font={size === "sm" ? "secondary-body" : "main-ui-muted"}
|
||||
>
|
||||
{units}
|
||||
</Text>
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
|
||||
{/* Buttons: < [page] > */}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import "@opal/components/tag/styles.css";
|
||||
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -16,7 +17,7 @@ interface TagProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Tag label text. */
|
||||
title: string;
|
||||
title: string | RichStr;
|
||||
|
||||
/** Color variant. Default: `"gray"`. */
|
||||
color?: TagColor;
|
||||
@@ -51,14 +52,13 @@ function Tag({ icon: Icon, title, color = "gray", size = "sm" }: TagProps) {
|
||||
<Icon className={cn("opal-auxiliary-tag-icon", config.text)} />
|
||||
</div>
|
||||
)}
|
||||
<span
|
||||
className={cn(
|
||||
"opal-auxiliary-tag-title px-[2px]",
|
||||
size === "md" ? "font-secondary-body" : "font-figure-small-value",
|
||||
config.text
|
||||
)}
|
||||
>
|
||||
{title}
|
||||
<span className={cn("opal-auxiliary-tag-title px-[2px]", config.text)}>
|
||||
<Text
|
||||
font={size === "md" ? "secondary-body" : "figure-small-value"}
|
||||
color="inherit"
|
||||
>
|
||||
{title}
|
||||
</Text>
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -10,7 +10,7 @@ import type { RichStr } from "@opal/types";
|
||||
|
||||
const SAFE_PROTOCOL = /^https?:|^mailto:|^tel:/i;
|
||||
|
||||
const ALLOWED_ELEMENTS = ["p", "a", "strong", "em", "code", "del"];
|
||||
const ALLOWED_ELEMENTS = ["p", "br", "a", "strong", "em", "code", "del"];
|
||||
|
||||
const INLINE_COMPONENTS = {
|
||||
p: ({ children }: { children?: ReactNode }) => <>{children}</>,
|
||||
@@ -41,6 +41,11 @@ interface InlineMarkdownProps {
|
||||
}
|
||||
|
||||
export default function InlineMarkdown({ content }: InlineMarkdownProps) {
|
||||
// Convert \n to CommonMark hard line breaks (two trailing spaces + newline).
|
||||
// react-markdown renders these as <br />, which inherits the parent's
|
||||
// line-height for font-appropriate spacing.
|
||||
const normalized = content.replace(/\n/g, " \n");
|
||||
|
||||
return (
|
||||
<ReactMarkdown
|
||||
components={INLINE_COMPONENTS}
|
||||
@@ -48,7 +53,7 @@ export default function InlineMarkdown({ content }: InlineMarkdownProps) {
|
||||
unwrapDisallowed
|
||||
remarkPlugins={[remarkGfm]}
|
||||
>
|
||||
{content}
|
||||
{normalized}
|
||||
</ReactMarkdown>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -90,15 +90,15 @@ import { markdown } from "@opal/utils";
|
||||
</Text>
|
||||
```
|
||||
|
||||
Supported syntax: `**bold**`, `*italic*`, `` `code` ``, `[link](url)`, `~~strikethrough~~`.
|
||||
Supported syntax: `**bold**`, `*italic*`, `` `code` ``, `[link](url)`, `~~strikethrough~~`, `\n` (newline → `<br />`).
|
||||
|
||||
Markdown rendering uses `react-markdown` internally, restricted to inline elements only.
|
||||
`http(s)` links open in a new tab; `mailto:` and `tel:` links open natively. Inline code
|
||||
inherits the parent font size and switches to the monospace family.
|
||||
|
||||
**Note:** This is inline-only markdown. Multi-paragraph content (`"Hello\n\nWorld"`) will
|
||||
collapse into a single run of text since paragraph wrappers are stripped. For block-level
|
||||
markdown, use `MinimalMarkdown` instead.
|
||||
Newlines (`\n`) are converted to `<br />` elements that inherit the parent's line-height,
|
||||
so line spacing is proportional to the font size. For full block-level markdown (code blocks,
|
||||
headings, lists), use `MinimalMarkdown` instead.
|
||||
|
||||
### Using `RichStr` in component props
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ type TextFont =
|
||||
| "figure-keystroke";
|
||||
|
||||
type TextColor =
|
||||
| "inherit"
|
||||
| "text-01"
|
||||
| "text-02"
|
||||
| "text-03"
|
||||
@@ -60,6 +61,9 @@ interface TextProps
|
||||
/** Prevent text wrapping. */
|
||||
nowrap?: boolean;
|
||||
|
||||
/** Truncate text to N lines with ellipsis. `1` uses simple truncation; `2+` uses `-webkit-line-clamp`. */
|
||||
maxLines?: number;
|
||||
|
||||
/** Plain string or `markdown()` for inline markdown. */
|
||||
children?: string | RichStr;
|
||||
}
|
||||
@@ -89,7 +93,8 @@ const FONT_CONFIG: Record<TextFont, string> = {
|
||||
"figure-keystroke": "font-figure-keystroke",
|
||||
};
|
||||
|
||||
const COLOR_CONFIG: Record<TextColor, string> = {
|
||||
const COLOR_CONFIG: Record<TextColor, string | null> = {
|
||||
inherit: null,
|
||||
"text-01": "text-text-01",
|
||||
"text-02": "text-text-02",
|
||||
"text-03": "text-text-03",
|
||||
@@ -115,17 +120,29 @@ function Text({
|
||||
color = "text-04",
|
||||
as: Tag = "span",
|
||||
nowrap,
|
||||
maxLines,
|
||||
children,
|
||||
...rest
|
||||
}: TextProps) {
|
||||
const resolvedClassName = cn(
|
||||
FONT_CONFIG[font],
|
||||
COLOR_CONFIG[color],
|
||||
nowrap && "whitespace-nowrap"
|
||||
nowrap && "whitespace-nowrap",
|
||||
maxLines === 1 && "truncate",
|
||||
maxLines && maxLines > 1 && "overflow-hidden"
|
||||
);
|
||||
|
||||
const style =
|
||||
maxLines && maxLines > 1
|
||||
? ({
|
||||
display: "-webkit-box",
|
||||
WebkitBoxOrient: "vertical",
|
||||
WebkitLineClamp: maxLines,
|
||||
} as React.CSSProperties)
|
||||
: undefined;
|
||||
|
||||
return (
|
||||
<Tag {...rest} className={resolvedClassName}>
|
||||
<Tag {...rest} className={resolvedClassName} style={style}>
|
||||
{children && resolveStr(children)}
|
||||
</Tag>
|
||||
);
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import "@opal/core/animations/styles.css";
|
||||
import React, { createContext, useContext, useState, useCallback } from "react";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import "@opal/core/disabled/styles.css";
|
||||
import React, { createContext, useContext } from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
|
||||
22
web/lib/opal/src/icons/bifrost.tsx
Normal file
22
web/lib/opal/src/icons/bifrost.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import { cn } from "@opal/utils";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBifrost = ({ size, className, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 37 46"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className={cn(className, "text-[#33C19E] dark:text-white")}
|
||||
{...props}
|
||||
>
|
||||
<title>Bifrost</title>
|
||||
<path
|
||||
d="M27.6219 46H0V36.8H27.6219V46ZM36.8268 36.8H27.6219V27.6H36.8268V36.8ZM18.4146 27.6H9.2073V18.4H18.4146V27.6ZM36.8268 18.4H27.6219V9.2H36.8268V18.4ZM27.6219 9.2H0V0H27.6219V9.2Z"
|
||||
fill="currentColor"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgBifrost;
|
||||
@@ -24,6 +24,7 @@ export { default as SvgAzure } from "@opal/icons/azure";
|
||||
export { default as SvgBarChart } from "@opal/icons/bar-chart";
|
||||
export { default as SvgBarChartSmall } from "@opal/icons/bar-chart-small";
|
||||
export { default as SvgBell } from "@opal/icons/bell";
|
||||
export { default as SvgBifrost } from "@opal/icons/bifrost";
|
||||
export { default as SvgBlocks } from "@opal/icons/blocks";
|
||||
export { default as SvgBookOpen } from "@opal/icons/book-open";
|
||||
export { default as SvgBookmark } from "@opal/icons/bookmark";
|
||||
|
||||
@@ -3,7 +3,11 @@
|
||||
import { Button } from "@opal/components/buttons/button/components";
|
||||
import type { ContainerSizeVariants } from "@opal/types";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import {
|
||||
resolveStr,
|
||||
toPlainString,
|
||||
} from "@opal/components/text/InlineMarkdown";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useState } from "react";
|
||||
|
||||
@@ -35,10 +39,10 @@ interface ContentLgProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text. */
|
||||
title: string;
|
||||
title: string | RichStr;
|
||||
|
||||
/** Optional description below the title. */
|
||||
description?: string;
|
||||
description?: string | RichStr;
|
||||
|
||||
/** Enable inline editing of the title. */
|
||||
editable?: boolean;
|
||||
@@ -96,18 +100,18 @@ function ContentLg({
|
||||
ref,
|
||||
}: ContentLgProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
const [editValue, setEditValue] = useState(toPlainString(title));
|
||||
|
||||
const config = CONTENT_LG_PRESETS[sizePreset];
|
||||
|
||||
function startEditing() {
|
||||
setEditValue(title);
|
||||
setEditValue(toPlainString(title));
|
||||
setEditing(true);
|
||||
}
|
||||
|
||||
function commit() {
|
||||
const value = editValue.trim();
|
||||
if (value && value !== title) onTitleChange?.(value);
|
||||
if (value && value !== toPlainString(title)) onTitleChange?.(value);
|
||||
setEditing(false);
|
||||
}
|
||||
|
||||
@@ -157,7 +161,7 @@ function ContentLg({
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") commit();
|
||||
if (e.key === "Escape") {
|
||||
setEditValue(title);
|
||||
setEditValue(toPlainString(title));
|
||||
setEditing(false);
|
||||
}
|
||||
}}
|
||||
@@ -174,9 +178,9 @@ function ContentLg({
|
||||
)}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
title={title}
|
||||
title={toPlainString(title)}
|
||||
>
|
||||
{title}
|
||||
{resolveStr(title)}
|
||||
</span>
|
||||
)}
|
||||
|
||||
@@ -199,9 +203,9 @@ function ContentLg({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{description && (
|
||||
{description && toPlainString(description) && (
|
||||
<div className="opal-content-lg-description font-secondary-body text-text-03">
|
||||
{description}
|
||||
{resolveStr(description)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -7,7 +7,11 @@ import SvgAlertCircle from "@opal/icons/alert-circle";
|
||||
import SvgAlertTriangle from "@opal/icons/alert-triangle";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import SvgXOctagon from "@opal/icons/x-octagon";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import {
|
||||
resolveStr,
|
||||
toPlainString,
|
||||
} from "@opal/components/text/InlineMarkdown";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useRef, useState } from "react";
|
||||
|
||||
@@ -25,7 +29,6 @@ interface ContentMdPresetConfig {
|
||||
iconColorClass: string;
|
||||
titleFont: string;
|
||||
lineHeight: string;
|
||||
gap: string;
|
||||
/** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
|
||||
editButtonSize: ContainerSizeVariants;
|
||||
editButtonPadding: string;
|
||||
@@ -41,10 +44,10 @@ interface ContentMdProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text. */
|
||||
title: string;
|
||||
title: string | RichStr;
|
||||
|
||||
/** Optional description text below the title. */
|
||||
description?: string;
|
||||
description?: string | RichStr;
|
||||
|
||||
/** Enable inline editing of the title. */
|
||||
editable?: boolean;
|
||||
@@ -70,6 +73,15 @@ interface ContentMdProps {
|
||||
/** When `true`, the title color hooks into `Interactive`'s `--interactive-foreground` variable. */
|
||||
withInteractive?: boolean;
|
||||
|
||||
/** Optional class name applied to the title element. */
|
||||
titleClassName?: string;
|
||||
|
||||
/** Optional class name applied to the icon element. */
|
||||
iconClassName?: string;
|
||||
|
||||
/** Content rendered below the description, indented to align with it. */
|
||||
bottomChildren?: React.ReactNode;
|
||||
|
||||
/** Ref forwarded to the root `<div>`. */
|
||||
ref?: React.Ref<HTMLDivElement>;
|
||||
}
|
||||
@@ -85,7 +97,6 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
|
||||
iconColorClass: "text-text-04",
|
||||
titleFont: "font-main-content-emphasis",
|
||||
lineHeight: "1.5rem",
|
||||
gap: "0.125rem",
|
||||
editButtonSize: "sm",
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-main-content-muted",
|
||||
@@ -98,7 +109,6 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
|
||||
iconColorClass: "text-text-03",
|
||||
titleFont: "font-main-ui-action",
|
||||
lineHeight: "1.25rem",
|
||||
gap: "0.25rem",
|
||||
editButtonSize: "xs",
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-main-ui-muted",
|
||||
@@ -111,7 +121,6 @@ const CONTENT_MD_PRESETS: Record<ContentMdSizePreset, ContentMdPresetConfig> = {
|
||||
iconColorClass: "text-text-04",
|
||||
titleFont: "font-secondary-action",
|
||||
lineHeight: "1rem",
|
||||
gap: "0.125rem",
|
||||
editButtonSize: "2xs",
|
||||
editButtonPadding: "p-0",
|
||||
optionalFont: "font-secondary-action",
|
||||
@@ -146,22 +155,25 @@ function ContentMd({
|
||||
tag,
|
||||
sizePreset = "main-ui",
|
||||
withInteractive,
|
||||
titleClassName,
|
||||
iconClassName,
|
||||
bottomChildren,
|
||||
ref,
|
||||
}: ContentMdProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
const [editValue, setEditValue] = useState(toPlainString(title));
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const config = CONTENT_MD_PRESETS[sizePreset];
|
||||
|
||||
function startEditing() {
|
||||
setEditValue(title);
|
||||
setEditValue(toPlainString(title));
|
||||
setEditing(true);
|
||||
}
|
||||
|
||||
function commit() {
|
||||
const value = editValue.trim();
|
||||
if (value && value !== title) onTitleChange?.(value);
|
||||
if (value && value !== toPlainString(title)) onTitleChange?.(value);
|
||||
setEditing(false);
|
||||
}
|
||||
|
||||
@@ -170,7 +182,6 @@ function ContentMd({
|
||||
ref={ref}
|
||||
className="opal-content-md"
|
||||
data-interactive={withInteractive || undefined}
|
||||
style={{ gap: config.gap }}
|
||||
>
|
||||
<div
|
||||
className="opal-content-md-header"
|
||||
@@ -185,7 +196,11 @@ function ContentMd({
|
||||
style={{ minHeight: config.lineHeight }}
|
||||
>
|
||||
<Icon
|
||||
className={cn("opal-content-md-icon", config.iconColorClass)}
|
||||
className={cn(
|
||||
"opal-content-md-icon",
|
||||
config.iconColorClass,
|
||||
iconClassName
|
||||
)}
|
||||
style={{ width: config.iconSize, height: config.iconSize }}
|
||||
/>
|
||||
</div>
|
||||
@@ -215,7 +230,7 @@ function ContentMd({
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") commit();
|
||||
if (e.key === "Escape") {
|
||||
setEditValue(title);
|
||||
setEditValue(toPlainString(title));
|
||||
setEditing(false);
|
||||
}
|
||||
}}
|
||||
@@ -228,13 +243,14 @@ function ContentMd({
|
||||
"opal-content-md-title",
|
||||
config.titleFont,
|
||||
"text-text-04",
|
||||
editable && "cursor-pointer"
|
||||
editable && "cursor-pointer",
|
||||
titleClassName
|
||||
)}
|
||||
title={title}
|
||||
title={toPlainString(title)}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
>
|
||||
{title}
|
||||
{resolveStr(title)}
|
||||
</span>
|
||||
)}
|
||||
|
||||
@@ -288,12 +304,19 @@ function ContentMd({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{description && (
|
||||
{description && toPlainString(description) && (
|
||||
<div
|
||||
className="opal-content-md-description font-secondary-body text-text-03"
|
||||
style={Icon ? { paddingLeft: config.descriptionIndent } : undefined}
|
||||
>
|
||||
{description}
|
||||
{resolveStr(description)}
|
||||
</div>
|
||||
)}
|
||||
{bottomChildren && (
|
||||
<div
|
||||
style={Icon ? { paddingLeft: config.descriptionIndent } : undefined}
|
||||
>
|
||||
{bottomChildren}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
"use client";
|
||||
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import {
|
||||
resolveStr,
|
||||
toPlainString,
|
||||
} from "@opal/components/text/InlineMarkdown";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -30,7 +34,7 @@ interface ContentSmProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text (read-only — editing is not supported). */
|
||||
title: string;
|
||||
title: string | RichStr;
|
||||
|
||||
/** Size preset. Default: `"main-ui"`. */
|
||||
sizePreset?: ContentSmSizePreset;
|
||||
@@ -118,9 +122,9 @@ function ContentSm({
|
||||
<span
|
||||
className={cn("opal-content-sm-title", config.titleFont)}
|
||||
style={{ height: config.lineHeight }}
|
||||
title={title}
|
||||
title={toPlainString(title)}
|
||||
>
|
||||
{title}
|
||||
{resolveStr(title)}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -3,7 +3,11 @@
|
||||
import { Button } from "@opal/components/buttons/button/components";
|
||||
import type { ContainerSizeVariants } from "@opal/types";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import {
|
||||
resolveStr,
|
||||
toPlainString,
|
||||
} from "@opal/components/text/InlineMarkdown";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useState } from "react";
|
||||
|
||||
@@ -41,10 +45,10 @@ interface ContentXlProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text. */
|
||||
title: string;
|
||||
title: string | RichStr;
|
||||
|
||||
/** Optional description below the title. */
|
||||
description?: string;
|
||||
description?: string | RichStr;
|
||||
|
||||
/** Enable inline editing of the title. */
|
||||
editable?: boolean;
|
||||
@@ -116,18 +120,18 @@ function ContentXl({
|
||||
ref,
|
||||
}: ContentXlProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
const [editValue, setEditValue] = useState(toPlainString(title));
|
||||
|
||||
const config = CONTENT_XL_PRESETS[sizePreset];
|
||||
|
||||
function startEditing() {
|
||||
setEditValue(title);
|
||||
setEditValue(toPlainString(title));
|
||||
setEditing(true);
|
||||
}
|
||||
|
||||
function commit() {
|
||||
const value = editValue.trim();
|
||||
if (value && value !== title) onTitleChange?.(value);
|
||||
if (value && value !== toPlainString(title)) onTitleChange?.(value);
|
||||
setEditing(false);
|
||||
}
|
||||
|
||||
@@ -214,7 +218,7 @@ function ContentXl({
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") commit();
|
||||
if (e.key === "Escape") {
|
||||
setEditValue(title);
|
||||
setEditValue(toPlainString(title));
|
||||
setEditing(false);
|
||||
}
|
||||
}}
|
||||
@@ -231,9 +235,9 @@ function ContentXl({
|
||||
)}
|
||||
onClick={editable ? startEditing : undefined}
|
||||
style={{ height: config.lineHeight }}
|
||||
title={title}
|
||||
title={toPlainString(title)}
|
||||
>
|
||||
{title}
|
||||
{resolveStr(title)}
|
||||
</span>
|
||||
)}
|
||||
|
||||
@@ -256,9 +260,9 @@ function ContentXl({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{description && (
|
||||
{description && toPlainString(description) && (
|
||||
<div className="opal-content-xl-description font-secondary-body text-text-03">
|
||||
{description}
|
||||
{resolveStr(description)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,3 +1,39 @@
|
||||
// ---------------------------------------------------------------------------
|
||||
// NOTE (@raunakab): Why Content uses resolveStr() instead of <Text>
|
||||
//
|
||||
// Content sub-components (ContentXl, ContentLg, ContentMd, ContentSm) render
|
||||
// titles and descriptions inside styled <span> elements that carry CSS classes
|
||||
// (e.g., `.opal-content-md-title`) for:
|
||||
//
|
||||
// 1. Truncation — `-webkit-box` + `-webkit-line-clamp` for single-line
|
||||
// clamping with ellipsis. This requires the text to be a DIRECT child
|
||||
// of the `-webkit-box` element. Wrapping it in a child <span> (which
|
||||
// is what <Text> renders) breaks the clamping behavior.
|
||||
//
|
||||
// 2. Pixel-exact sizing — the wrapper <span> has an explicit `height`
|
||||
// matching the font's `line-height`. Adding a child <Text> <span>
|
||||
// inside creates a double-span where the inner element's line-height
|
||||
// conflicts with the outer element's height, causing a ~4px vertical
|
||||
// offset.
|
||||
//
|
||||
// 3. Interactive color overrides — CSS selectors like
|
||||
// `.opal-content-md[data-interactive] .opal-content-md-title` set
|
||||
// `color: var(--interactive-foreground)`. <Text> with `color="inherit"`
|
||||
// can inherit this, but <Text> with any explicit color prop overrides
|
||||
// it. And the wrapper <span> needs the CSS class for the selector to
|
||||
// match — removing it breaks the cascade.
|
||||
//
|
||||
// 4. Horizontal padding — the title CSS class applies `padding: 0 0.125rem`
|
||||
// (2px). Since <Text> uses WithoutStyles (no className/style), this
|
||||
// padding cannot be applied to <Text> directly. A wrapper <div> was
|
||||
// attempted but introduced additional layout conflicts.
|
||||
//
|
||||
// For these reasons, Content uses `resolveStr()` from InlineMarkdown.tsx to
|
||||
// handle `string | RichStr` rendering. `resolveStr()` returns a ReactNode
|
||||
// that slots directly into the existing single <span> — no extra wrapper,
|
||||
// no layout conflicts, pixel-exact match with main.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
import "@opal/layouts/content/styles.css";
|
||||
import {
|
||||
ContentSm,
|
||||
@@ -17,7 +53,7 @@ import {
|
||||
type ContentMdProps,
|
||||
} from "@opal/layouts/content/ContentMd";
|
||||
import type { TagProps } from "@opal/components/tag/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import { widthVariants } from "@opal/shared";
|
||||
import type { ExtremaSizeVariants } from "@opal/types";
|
||||
|
||||
@@ -39,10 +75,10 @@ interface ContentBaseProps {
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Main title text. */
|
||||
title: string;
|
||||
title: string | RichStr;
|
||||
|
||||
/** Optional description below the title. */
|
||||
description?: string;
|
||||
description?: string | RichStr;
|
||||
|
||||
/** Enable inline editing of the title. */
|
||||
editable?: boolean;
|
||||
@@ -67,6 +103,12 @@ interface ContentBaseProps {
|
||||
|
||||
/** Ref forwarded to the root `<div>` of the resolved layout. */
|
||||
ref?: React.Ref<HTMLDivElement>;
|
||||
|
||||
/** Optional class name applied to the icon element. */
|
||||
iconClassName?: string;
|
||||
|
||||
/** Content rendered below the description, indented to align with it (MdContent only). */
|
||||
bottomChildren?: React.ReactNode;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -102,6 +144,8 @@ type MdContentProps = ContentBaseProps & {
|
||||
auxIcon?: "info-gray" | "info-blue" | "warning" | "error";
|
||||
/** Tag rendered beside the title. */
|
||||
tag?: TagProps;
|
||||
/** Optional class name applied to the title element. */
|
||||
titleClassName?: string;
|
||||
};
|
||||
|
||||
/** ContentSm does not support descriptions or inline editing. */
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import type { IconFunctionComponent, RichStr } from "@opal/types";
|
||||
import { Text } from "@opal/components";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
@@ -9,10 +10,10 @@ interface IllustrationContentProps {
|
||||
illustration?: IconFunctionComponent;
|
||||
|
||||
/** Main title text, center-aligned. Uses `font-main-content-emphasis`. */
|
||||
title: string;
|
||||
title: string | RichStr;
|
||||
|
||||
/** Optional description below the title, center-aligned. Uses `font-secondary-body`. */
|
||||
description?: string;
|
||||
description?: string | RichStr;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -68,9 +69,13 @@ function IllustrationContent({
|
||||
/>
|
||||
)}
|
||||
<div className="flex flex-col items-center text-center">
|
||||
<p className="font-main-content-emphasis text-text-04">{title}</p>
|
||||
<Text font="main-content-emphasis" color="text-04" as="p">
|
||||
{title}
|
||||
</Text>
|
||||
{description && (
|
||||
<p className="font-secondary-body text-text-03">{description}</p>
|
||||
<Text font="secondary-body" color="text-03" as="p">
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -30,8 +30,11 @@ import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { Button } from "@opal/components";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgEdit, SvgInfo, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
|
||||
import { useBillingInformation } from "@/hooks/useBillingInformation";
|
||||
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTES.API_KEYS;
|
||||
@@ -44,6 +47,11 @@ function Main() {
|
||||
} = useSWR<APIKey[]>("/api/admin/api-key", errorHandlingFetcher);
|
||||
|
||||
const canCreateKeys = useCloudSubscription();
|
||||
const { data: billingData } = useBillingInformation();
|
||||
const isTrialing =
|
||||
billingData !== undefined &&
|
||||
hasActiveSubscription(billingData) &&
|
||||
billingData.status === BillingStatus.TRIALING;
|
||||
|
||||
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
|
||||
const [keyIsGenerating, setKeyIsGenerating] = useState(false);
|
||||
@@ -75,6 +83,16 @@ function Main() {
|
||||
|
||||
const introSection = (
|
||||
<div className="flex flex-col items-start gap-4">
|
||||
{isTrialing && (
|
||||
<Message
|
||||
static
|
||||
warning
|
||||
close={false}
|
||||
className="w-full"
|
||||
text="Upgrade to a paid plan to create API keys."
|
||||
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
|
||||
/>
|
||||
)}
|
||||
<Text as="p">
|
||||
API Keys allow you to access Onyx APIs programmatically.
|
||||
{canCreateKeys
|
||||
@@ -85,23 +103,9 @@ function Main() {
|
||||
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
|
||||
Create API Key
|
||||
</CreateButton>
|
||||
) : (
|
||||
<div className="flex flex-col gap-2 rounded-lg bg-background-tint-02 p-4">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Text as="p" text04>
|
||||
Upgrade to a paid plan to create API keys.
|
||||
</Text>
|
||||
<Button
|
||||
variant="none"
|
||||
prominence="tertiary"
|
||||
size="2xs"
|
||||
icon={SvgInfo}
|
||||
tooltip="API keys enable programmatic access to Onyx for service accounts and integrations. Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
|
||||
/>
|
||||
</div>
|
||||
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
|
||||
</div>
|
||||
)}
|
||||
) : isTrialing ? (
|
||||
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
|
||||
|
||||
387
web/src/app/admin/billing/page.test.tsx
Normal file
387
web/src/app/admin/billing/page.test.tsx
Normal file
@@ -0,0 +1,387 @@
|
||||
/**
|
||||
* Tests for BillingPage handleBillingReturn retry logic.
|
||||
*
|
||||
* The retry logic retries claimLicense up to 3 times with 2s backoff
|
||||
* when returning from a Stripe checkout session. This prevents the user
|
||||
* from getting stranded when the Stripe webhook fires concurrently with
|
||||
* the browser redirect and the license isn't ready yet.
|
||||
*/
|
||||
import React from "react";
|
||||
import { render, screen, waitFor } from "@tests/setup/test-utils";
|
||||
import { act } from "@testing-library/react";
|
||||
|
||||
// ---- Stable mock objects (must be named with mock* prefix for jest hoisting) ----
|
||||
// useRouter and useSearchParams must return the SAME reference each call, otherwise
|
||||
// React's useEffect sees them as changed and re-runs the effect on every render.
|
||||
const mockRouter = {
|
||||
replace: jest.fn() as jest.Mock,
|
||||
refresh: jest.fn() as jest.Mock,
|
||||
};
|
||||
const mockSearchParams = {
|
||||
get: jest.fn() as jest.Mock,
|
||||
};
|
||||
const mockClaimLicense = jest.fn() as jest.Mock;
|
||||
const mockRefreshBilling = jest.fn() as jest.Mock;
|
||||
const mockRefreshLicense = jest.fn() as jest.Mock;
|
||||
|
||||
// ---- Mocks ----
|
||||
|
||||
jest.mock("next/navigation", () => ({
|
||||
useRouter: () => mockRouter,
|
||||
useSearchParams: () => mockSearchParams,
|
||||
}));
|
||||
|
||||
jest.mock("@/layouts/settings-layouts", () => ({
|
||||
Root: ({ children }: { children: React.ReactNode }) => (
|
||||
<div data-testid="settings-root">{children}</div>
|
||||
),
|
||||
Header: () => <div data-testid="settings-header" />,
|
||||
Body: ({ children }: { children: React.ReactNode }) => (
|
||||
<div data-testid="settings-body">{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@/layouts/general-layouts", () => ({
|
||||
Section: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@opal/icons", () => ({
|
||||
SvgArrowUpCircle: () => <svg />,
|
||||
SvgWallet: () => <svg />,
|
||||
}));
|
||||
|
||||
jest.mock("./PlansView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="plans-view" />,
|
||||
}));
|
||||
jest.mock("./CheckoutView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="checkout-view" />,
|
||||
}));
|
||||
jest.mock("./BillingDetailsView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="billing-details-view" />,
|
||||
}));
|
||||
jest.mock("./LicenseActivationCard", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="license-activation-card" />,
|
||||
}));
|
||||
|
||||
jest.mock("@/refresh-components/messages/Message", () => ({
|
||||
__esModule: true,
|
||||
default: ({
|
||||
text,
|
||||
description,
|
||||
onClose,
|
||||
}: {
|
||||
text: string;
|
||||
description?: string;
|
||||
onClose?: () => void;
|
||||
}) => (
|
||||
<div data-testid="activating-banner">
|
||||
<span data-testid="activating-banner-text">{text}</span>
|
||||
{description && (
|
||||
<span data-testid="activating-banner-description">{description}</span>
|
||||
)}
|
||||
{onClose && (
|
||||
<button data-testid="activating-banner-close" onClick={onClose}>
|
||||
Close
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@/lib/billing", () => ({
|
||||
useBillingInformation: jest.fn(),
|
||||
useLicense: jest.fn(),
|
||||
hasActiveSubscription: jest.fn().mockReturnValue(false),
|
||||
claimLicense: (...args: unknown[]) => mockClaimLicense(...args),
|
||||
}));
|
||||
|
||||
jest.mock("@/lib/constants", () => ({
|
||||
NEXT_PUBLIC_CLOUD_ENABLED: false,
|
||||
}));
|
||||
|
||||
// ---- Import after mocks ----
|
||||
import BillingPage from "./page";
|
||||
import { useBillingInformation, useLicense } from "@/lib/billing";
|
||||
|
||||
// ---- Test helpers ----
|
||||
|
||||
function setupHooks() {
|
||||
(useBillingInformation as jest.Mock).mockReturnValue({
|
||||
data: null,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
refresh: mockRefreshBilling,
|
||||
});
|
||||
(useLicense as jest.Mock).mockReturnValue({
|
||||
data: null,
|
||||
isLoading: false,
|
||||
refresh: mockRefreshLicense,
|
||||
});
|
||||
}
|
||||
|
||||
// ---- Tests ----
|
||||
|
||||
describe("BillingPage — handleBillingReturn retry logic", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
jest.useFakeTimers();
|
||||
setupHooks();
|
||||
// Default: no billing-return params
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
// Clear any activating state from prior tests
|
||||
sessionStorage.clear();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
test("calls claimLicense once and refreshes on first-attempt success", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_test_123" : null
|
||||
);
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith("cs_test_123");
|
||||
});
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
// URL cleaned up after checkout return
|
||||
expect(mockRouter.replace).toHaveBeenCalledWith("/admin/billing", {
|
||||
scroll: false,
|
||||
});
|
||||
});
|
||||
|
||||
test("retries after first failure and succeeds on second attempt", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_retry_test" : null
|
||||
);
|
||||
mockClaimLicense
|
||||
.mockRejectedValueOnce(new Error("License not ready yet"))
|
||||
.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
// On eventual success, router and billing should be refreshed
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("retries all 3 times then navigates to details even on total failure", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_all_fail" : null
|
||||
);
|
||||
// All 3 attempts fail
|
||||
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
|
||||
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
// User stays on plans view with the activating banner
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("plans-view")).toBeInTheDocument();
|
||||
});
|
||||
// refreshBilling still fires so billing state is up to date
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
// Failure is logged
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("Failed to sync license after billing return"),
|
||||
expect.any(Error)
|
||||
);
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("calls claimLicense without session_id on portal_return", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "portal_return" ? "true" : null
|
||||
);
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
|
||||
// No session_id for portal returns — called with undefined
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
|
||||
});
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("does not call claimLicense when no billing-return params present", async () => {
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(mockClaimLicense).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("shows activating banner and sets sessionStorage on 3x retry failure", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_all_fail" : null
|
||||
);
|
||||
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
|
||||
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
});
|
||||
expect(screen.getByTestId("activating-banner-text")).toHaveTextContent(
|
||||
"Your license is still activating"
|
||||
);
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).not.toBeNull();
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("banner not rendered when no activating state", async () => {
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("banner shown on mount when sessionStorage key is set and not expired", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush React effects — banner is visible from lazy state init, no timer advancement needed
|
||||
await act(async () => {});
|
||||
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("banner not shown on mount when sessionStorage key is expired", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() - 1000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
test("poll calls claimLicense after 15s and clears banner on success", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
// Poll attempt succeeds
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush effects — banner visible from lazy state init
|
||||
await act(async () => {});
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
|
||||
// Advance past one poll interval (15s)
|
||||
await act(async () => {
|
||||
await jest.advanceTimersByTimeAsync(15_000);
|
||||
});
|
||||
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
expect(mockRefreshLicense).toHaveBeenCalled();
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("close button removes banner and clears sessionStorage", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush effects — banner visible from lazy state init
|
||||
await act(async () => {});
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
|
||||
const closeButton = screen.getByTestId("activating-banner-close");
|
||||
await act(async () => {
|
||||
closeButton.click();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
});
|
||||
});
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
} from "@/lib/billing";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
|
||||
import PlansView from "./PlansView";
|
||||
import CheckoutView from "./CheckoutView";
|
||||
@@ -24,6 +25,9 @@ import BillingDetailsView from "./BillingDetailsView";
|
||||
import LicenseActivationCard from "./LicenseActivationCard";
|
||||
import "./billing.css";
|
||||
|
||||
// sessionStorage key: value is a unix-ms expiry timestamp
|
||||
const BILLING_ACTIVATING_KEY = "billing_license_activating_until";
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Types
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -105,6 +109,7 @@ export default function BillingPage() {
|
||||
const [transitionType, setTransitionType] = useState<
|
||||
"expand" | "collapse" | "fade"
|
||||
>("fade");
|
||||
const [isActivating, setIsActivating] = useState<boolean>(false);
|
||||
|
||||
const {
|
||||
data: billingData,
|
||||
@@ -155,6 +160,17 @@ export default function BillingPage() {
|
||||
view,
|
||||
]);
|
||||
|
||||
// Read activating state from sessionStorage after mount (avoids SSR hydration mismatch)
|
||||
useEffect(() => {
|
||||
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
|
||||
if (!raw) return;
|
||||
if (Number(raw) > Date.now()) {
|
||||
setIsActivating(true);
|
||||
} else {
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Show license activation card when there's a Stripe error
|
||||
useEffect(() => {
|
||||
if (hasStripeError && !showLicenseActivationInput) {
|
||||
@@ -172,24 +188,96 @@ export default function BillingPage() {
|
||||
|
||||
router.replace("/admin/billing", { scroll: false });
|
||||
|
||||
let cancelled = false;
|
||||
|
||||
const handleBillingReturn = async () => {
|
||||
if (!NEXT_PUBLIC_CLOUD_ENABLED) {
|
||||
try {
|
||||
// After checkout, exchange session_id for license; after portal, re-sync license
|
||||
await claimLicense(sessionId ?? undefined);
|
||||
refreshLicense();
|
||||
// Refresh the page to update settings (including ee_features_enabled)
|
||||
router.refresh();
|
||||
// Navigate to billing details now that the license is active
|
||||
changeView("details");
|
||||
} catch (error) {
|
||||
console.error("Failed to sync license after billing return:", error);
|
||||
// Retry up to 3 times with 2s backoff. The license may not be available
|
||||
// immediately if the Stripe webhook hasn't finished processing yet
|
||||
// (redirect and webhook fire nearly simultaneously).
|
||||
let lastError: Error | null = null;
|
||||
for (let attempt = 0; attempt < 3; attempt++) {
|
||||
if (cancelled) return;
|
||||
try {
|
||||
// After checkout, exchange session_id for license; after portal, re-sync license
|
||||
await claimLicense(sessionId ?? undefined);
|
||||
if (cancelled) return;
|
||||
refreshLicense();
|
||||
// Refresh the page to update settings (including ee_features_enabled)
|
||||
router.refresh();
|
||||
// Navigate to billing details now that the license is active
|
||||
changeView("details");
|
||||
lastError = null;
|
||||
break;
|
||||
} catch (err) {
|
||||
lastError = err instanceof Error ? err : new Error("Unknown error");
|
||||
if (attempt < 2) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cancelled) return;
|
||||
if (lastError) {
|
||||
console.error(
|
||||
"Failed to sync license after billing return:",
|
||||
lastError
|
||||
);
|
||||
// Show an activating banner on the plans view and keep retrying in the background.
|
||||
sessionStorage.setItem(
|
||||
BILLING_ACTIVATING_KEY,
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
setIsActivating(true);
|
||||
changeView("plans");
|
||||
}
|
||||
}
|
||||
refreshBilling();
|
||||
if (!cancelled) refreshBilling();
|
||||
};
|
||||
handleBillingReturn();
|
||||
}, [searchParams, router, refreshBilling, refreshLicense]);
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
// changeView intentionally omitted: it only calls stable state setters and the
|
||||
// effect runs at most once (when session_id/portal_return params are present).
|
||||
}, [searchParams, router, refreshBilling, refreshLicense]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
// Poll every 15s while activating, up to 2 minutes, to detect when the license arrives.
|
||||
useEffect(() => {
|
||||
if (!isActivating) return;
|
||||
|
||||
let requestInFlight = false;
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
if (requestInFlight) return;
|
||||
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
|
||||
if (!raw || Number(raw) <= Date.now()) {
|
||||
// Expired — stop immediately without waiting for React cleanup
|
||||
clearInterval(intervalId);
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
return;
|
||||
}
|
||||
requestInFlight = true;
|
||||
try {
|
||||
await claimLicense(undefined);
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
refreshLicense();
|
||||
refreshBilling();
|
||||
router.refresh();
|
||||
changeView("details");
|
||||
} catch (err) {
|
||||
// License not ready yet — keep polling. Log so unexpected failures
|
||||
// (network errors, 500s) are distinguishable from expected 404s.
|
||||
console.debug("License activation poll: will retry", err);
|
||||
} finally {
|
||||
requestInFlight = false;
|
||||
}
|
||||
}, 15_000);
|
||||
|
||||
return () => clearInterval(intervalId);
|
||||
}, [isActivating]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
const handleRefresh = async () => {
|
||||
await Promise.all([
|
||||
@@ -386,6 +474,22 @@ export default function BillingPage() {
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<div className="flex flex-col items-center gap-6">
|
||||
{isActivating && (
|
||||
<Message
|
||||
static
|
||||
warning
|
||||
large
|
||||
text="Your license is still activating"
|
||||
description="Your license is being processed. You'll be taken to billing details automatically once confirmed."
|
||||
icon
|
||||
close
|
||||
onClose={() => {
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
}}
|
||||
className="w-full"
|
||||
/>
|
||||
)}
|
||||
{renderContent()}
|
||||
{renderFooter()}
|
||||
</div>
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useMemo } from "react";
|
||||
import { useState, useMemo, useEffect } from "react";
|
||||
import useSWR from "swr";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
|
||||
import {
|
||||
@@ -17,9 +18,16 @@ import {
|
||||
ImageGenerationConfigView,
|
||||
setDefaultImageGenerationConfig,
|
||||
unsetDefaultImageGenerationConfig,
|
||||
deleteImageGenerationConfig,
|
||||
} from "@/lib/configuration/imageConfigurationService";
|
||||
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
export default function ImageGenerationContent() {
|
||||
const {
|
||||
@@ -47,6 +55,11 @@ export default function ImageGenerationContent() {
|
||||
);
|
||||
const [editConfig, setEditConfig] =
|
||||
useState<ImageGenerationConfigView | null>(null);
|
||||
const [disconnectProvider, setDisconnectProvider] =
|
||||
useState<ImageProvider | null>(null);
|
||||
const [replacementProviderId, setReplacementProviderId] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
|
||||
const connectedProviderIds = useMemo(() => {
|
||||
return new Set(configs.map((c) => c.image_provider_id));
|
||||
@@ -115,6 +128,29 @@ export default function ImageGenerationContent() {
|
||||
modal.toggle(true);
|
||||
};
|
||||
|
||||
const handleDisconnect = async () => {
|
||||
if (!disconnectProvider) return;
|
||||
try {
|
||||
// If a replacement was selected (not "No Default"), activate it first
|
||||
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
|
||||
await setDefaultImageGenerationConfig(replacementProviderId);
|
||||
}
|
||||
|
||||
await deleteImageGenerationConfig(disconnectProvider.image_provider_id);
|
||||
toast.success(`${disconnectProvider.title} disconnected`);
|
||||
refetchConfigs();
|
||||
refetchProviders();
|
||||
} catch (error) {
|
||||
console.error("Failed to disconnect image generation provider:", error);
|
||||
toast.error(
|
||||
error instanceof Error ? error.message : "Failed to disconnect"
|
||||
);
|
||||
} finally {
|
||||
setDisconnectProvider(null);
|
||||
setReplacementProviderId(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleModalSuccess = () => {
|
||||
toast.success("Provider configured successfully");
|
||||
setEditConfig(null);
|
||||
@@ -130,6 +166,36 @@ export default function ImageGenerationContent() {
|
||||
);
|
||||
}
|
||||
|
||||
// Compute replacement options when disconnecting an active provider
|
||||
const isDisconnectingDefault =
|
||||
disconnectProvider &&
|
||||
defaultConfig?.image_provider_id === disconnectProvider.image_provider_id;
|
||||
|
||||
// Group connected replacement models by provider (excluding the model being disconnected)
|
||||
const replacementGroups = useMemo(() => {
|
||||
if (!disconnectProvider) return [];
|
||||
return IMAGE_PROVIDER_GROUPS.map((group) => ({
|
||||
...group,
|
||||
providers: group.providers.filter(
|
||||
(p) =>
|
||||
p.image_provider_id !== disconnectProvider.image_provider_id &&
|
||||
connectedProviderIds.has(p.image_provider_id)
|
||||
),
|
||||
})).filter((g) => g.providers.length > 0);
|
||||
}, [disconnectProvider, connectedProviderIds]);
|
||||
|
||||
const needsReplacement = !!isDisconnectingDefault;
|
||||
const hasReplacements = replacementGroups.length > 0;
|
||||
|
||||
// Auto-select first replacement when modal opens
|
||||
useEffect(() => {
|
||||
if (needsReplacement && !replacementProviderId && hasReplacements) {
|
||||
const firstGroup = replacementGroups[0];
|
||||
const firstModel = firstGroup?.providers[0];
|
||||
if (firstModel) setReplacementProviderId(firstModel.image_provider_id);
|
||||
}
|
||||
}, [disconnectProvider]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col gap-6">
|
||||
@@ -175,6 +241,11 @@ export default function ImageGenerationContent() {
|
||||
onSelect={() => handleSelect(provider)}
|
||||
onDeselect={() => handleDeselect(provider)}
|
||||
onEdit={() => handleEdit(provider)}
|
||||
onDisconnect={
|
||||
getStatus(provider) !== "disconnected"
|
||||
? () => setDisconnectProvider(provider)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
@@ -182,6 +253,105 @@ export default function ImageGenerationContent() {
|
||||
))}
|
||||
</div>
|
||||
|
||||
{disconnectProvider && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title={`Disconnect ${disconnectProvider.title}`}
|
||||
description="This will remove the stored credentials for this provider."
|
||||
onClose={() => {
|
||||
setDisconnectProvider(null);
|
||||
setReplacementProviderId(null);
|
||||
}}
|
||||
submit={
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={() => void handleDisconnect()}
|
||||
disabled={
|
||||
needsReplacement && hasReplacements && !replacementProviderId
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> is currently the default
|
||||
image generation model. Session history will be preserved.
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" text04>
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
value={replacementProviderId ?? undefined}
|
||||
onValueChange={(v) => setReplacementProviderId(v)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a replacement model" />
|
||||
<InputSelect.Content>
|
||||
{replacementGroups.map((group) => (
|
||||
<InputSelect.Group key={group.name}>
|
||||
<InputSelect.Label>{group.name}</InputSelect.Label>
|
||||
{group.providers.map((p) => (
|
||||
<InputSelect.Item
|
||||
key={p.image_provider_id}
|
||||
value={p.image_provider_id}
|
||||
icon={() => (
|
||||
<ProviderIcon
|
||||
provider={p.provider_name}
|
||||
size={16}
|
||||
/>
|
||||
)}
|
||||
>
|
||||
{p.title}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Group>
|
||||
))}
|
||||
<InputSelect.Separator />
|
||||
<InputSelect.Item
|
||||
value={NO_DEFAULT_VALUE}
|
||||
icon={SvgSlash}
|
||||
>
|
||||
<span>
|
||||
<b>No Default</b>
|
||||
<span className="text-text-03">
|
||||
{" "}
|
||||
(Disable Image Generation)
|
||||
</span>
|
||||
</span>
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> is currently the default
|
||||
image generation model.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Connect another provider to continue using image generation.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> models will no longer be used
|
||||
to generate images.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Session history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
)}
|
||||
|
||||
{activeProvider && (
|
||||
<modal.Provider>
|
||||
<ImageGenerationConnectionModal
|
||||
|
||||
@@ -23,6 +23,7 @@ import {
|
||||
BedrockModelResponse,
|
||||
LMStudioModelResponse,
|
||||
LiteLLMProxyModelResponse,
|
||||
BifrostModelResponse,
|
||||
ModelConfiguration,
|
||||
LLMProviderName,
|
||||
BedrockFetchParams,
|
||||
@@ -30,8 +31,9 @@ import {
|
||||
LMStudioFetchParams,
|
||||
OpenRouterFetchParams,
|
||||
LiteLLMProxyFetchParams,
|
||||
BifrostFetchParams,
|
||||
} from "@/interfaces/llm";
|
||||
import { SvgAws, SvgOpenrouter } from "@opal/icons";
|
||||
import { SvgAws, SvgBifrost, SvgOpenrouter } from "@opal/icons";
|
||||
|
||||
// Aggregator providers that host models from multiple vendors
|
||||
export const AGGREGATOR_PROVIDERS = new Set([
|
||||
@@ -41,6 +43,7 @@ export const AGGREGATOR_PROVIDERS = new Set([
|
||||
"ollama_chat",
|
||||
"lm_studio",
|
||||
"litellm_proxy",
|
||||
"bifrost",
|
||||
"vertex_ai",
|
||||
]);
|
||||
|
||||
@@ -78,6 +81,7 @@ export const getProviderIcon = (
|
||||
bedrock_converse: SvgAws,
|
||||
openrouter: SvgOpenrouter,
|
||||
litellm_proxy: LiteLLMIcon,
|
||||
bifrost: SvgBifrost,
|
||||
vertex_ai: GeminiIcon,
|
||||
};
|
||||
|
||||
@@ -263,8 +267,11 @@ export const fetchOpenRouterModels = async (
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse OpenRouter model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
@@ -319,8 +326,11 @@ export const fetchLMStudioModels = async (
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse LM Studio model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
@@ -343,6 +353,64 @@ export const fetchLMStudioModels = async (
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches Bifrost models directly without any form state dependencies.
|
||||
* Uses snake_case params to match API structure.
|
||||
*/
|
||||
export const fetchBifrostModels = async (
|
||||
params: BifrostFetchParams
|
||||
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
|
||||
const apiBase = params.api_base;
|
||||
if (!apiBase) {
|
||||
return { models: [], error: "API Base is required" };
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch("/api/admin/llm/bifrost/available-models", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_base: apiBase,
|
||||
api_key: params.api_key,
|
||||
provider_name: params.provider_name,
|
||||
}),
|
||||
signal: params.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = "Failed to fetch models";
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse Bifrost model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
|
||||
const data: BifrostModelResponse[] = await response.json();
|
||||
const models: ModelConfiguration[] = data.map((modelData) => ({
|
||||
name: modelData.name,
|
||||
display_name: modelData.display_name,
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: modelData.supports_reasoning,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : "Unknown error";
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches LiteLLM Proxy models directly without any form state dependencies.
|
||||
* Uses snake_case params to match API structure.
|
||||
@@ -456,6 +524,13 @@ export const fetchModels = async (
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
case LLMProviderName.BIFROST:
|
||||
return fetchBifrostModels({
|
||||
api_base: formValues.api_base,
|
||||
api_key: formValues.api_key,
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
default:
|
||||
return { models: [], error: `Unknown provider: ${providerName}` };
|
||||
}
|
||||
@@ -469,6 +544,7 @@ export function canProviderFetchModels(providerName?: string) {
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
case LLMProviderName.OPENROUTER:
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
case LLMProviderName.BIFROST:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -7,7 +7,9 @@ import {
|
||||
FailedConnectorIndexingStatus,
|
||||
ValidStatuses,
|
||||
} from "@/lib/types";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import Title from "@/components/ui/title";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
@@ -199,45 +201,30 @@ export default function UpgradingPage({
|
||||
/>
|
||||
)}
|
||||
|
||||
<Text className="my-4">
|
||||
{futureEmbeddingModel.switchover_type === "active_only" ? (
|
||||
<>
|
||||
The table below shows the re-indexing progress of active
|
||||
(non-paused) connectors. Once all active connectors have
|
||||
been re-indexed successfully, the new model will be used
|
||||
for all search queries. Paused connectors will continue
|
||||
to be indexed in the background but won't block the
|
||||
switchover. Until then, we will use the old model so
|
||||
that no downtime is necessary during this transition.
|
||||
<br />
|
||||
Note: User file re-indexing progress is not shown. You
|
||||
will see this page until all active connectors are
|
||||
re-indexed!
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
The table below shows the re-indexing progress of all
|
||||
existing connectors. Once all connectors have been
|
||||
re-indexed successfully, the new model will be used for
|
||||
all search queries. Until then, we will use the old
|
||||
model so that no downtime is necessary during this
|
||||
transition.
|
||||
<br />
|
||||
Note: User file re-indexing progress is not shown. You
|
||||
will see this page until all user files are re-indexed!
|
||||
</>
|
||||
)}
|
||||
<Spacer rem={1} />
|
||||
<Text as="p">
|
||||
{futureEmbeddingModel.switchover_type === "active_only"
|
||||
? markdown(
|
||||
"The table below shows the re-indexing progress of active (non-paused) connectors. Once all active connectors have been re-indexed successfully, the new model will be used for all search queries. Paused connectors will continue to be indexed in the background but won't block the switchover. Until then, we will use the old model so that no downtime is necessary during this transition.\nNote: User file re-indexing progress is not shown. You will see this page until all active connectors are re-indexed!"
|
||||
)
|
||||
: markdown(
|
||||
"The table below shows the re-indexing progress of all existing connectors. Once all connectors have been re-indexed successfully, the new model will be used for all search queries. Until then, we will use the old model so that no downtime is necessary during this transition.\nNote: User file re-indexing progress is not shown. You will see this page until all user files are re-indexed!"
|
||||
)}
|
||||
</Text>
|
||||
<Spacer rem={1} />
|
||||
|
||||
{sortedReindexingProgress ? (
|
||||
<>
|
||||
{futureEmbeddingModel.switchover_type === "active_only" &&
|
||||
!hasVisibleReindexingProgress && (
|
||||
<Text className="text-text-700 mt-4">
|
||||
All connectors are currently paused, so none are
|
||||
blocking the switchover. Paused connectors will keep
|
||||
re-indexing in the background.
|
||||
</Text>
|
||||
<>
|
||||
<Spacer rem={1} />
|
||||
<Text as="p">
|
||||
All connectors are currently paused, so none are
|
||||
blocking the switchover. Paused connectors will
|
||||
keep re-indexing in the background.
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
{hasVisibleReindexingProgress && (
|
||||
<ReindexingProgressTable
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import Title from "@/components/ui/title";
|
||||
import { Button } from "@opal/components";
|
||||
import useSWR from "swr";
|
||||
@@ -107,8 +107,10 @@ function Main() {
|
||||
<div className="px-1 w-full rounded-lg">
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<Text className="font-semibold">Multipass Indexing</Text>
|
||||
<Text className="text-text-700">
|
||||
<Text as="p" font="main-ui-action">
|
||||
Multipass Indexing
|
||||
</Text>
|
||||
<Text as="p">
|
||||
{searchSettings.multipass_indexing
|
||||
? "Enabled"
|
||||
: "Disabled"}
|
||||
@@ -116,8 +118,10 @@ function Main() {
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Text className="font-semibold">Contextual RAG</Text>
|
||||
<Text className="text-text-700">
|
||||
<Text as="p" font="main-ui-action">
|
||||
Contextual RAG
|
||||
</Text>
|
||||
<Text as="p">
|
||||
{searchSettings.enable_contextual_rag
|
||||
? "Enabled"
|
||||
: "Disabled"}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { markdown } from "@opal/utils";
|
||||
import Image from "next/image";
|
||||
import { FunctionComponent, useState, useEffect } from "react";
|
||||
import {
|
||||
@@ -436,22 +437,9 @@ export default function VoiceProviderSetupModal({
|
||||
{providerType === "azure" && (
|
||||
<Vertical
|
||||
title="Target URI"
|
||||
subDescription={
|
||||
<>
|
||||
Paste the endpoint shown in{" "}
|
||||
<a
|
||||
href="https://portal.azure.com/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
Azure Portal (Keys and Endpoint)
|
||||
</a>
|
||||
. Onyx extracts the speech region from this URL. Examples:
|
||||
https://westus.api.cognitive.microsoft.com/ or
|
||||
https://westus.tts.speech.microsoft.com/.
|
||||
</>
|
||||
}
|
||||
subDescription={markdown(
|
||||
"Paste the endpoint shown in [Azure Portal (Keys and Endpoint)](https://portal.azure.com/). Onyx extracts the speech region from this URL. Examples: https://westus.api.cognitive.microsoft.com/ or https://westus.tts.speech.microsoft.com/."
|
||||
)}
|
||||
nonInteractive
|
||||
>
|
||||
<InputTypeIn
|
||||
@@ -503,24 +491,14 @@ export default function VoiceProviderSetupModal({
|
||||
{mode === "tts" && (
|
||||
<Vertical
|
||||
title="Voice"
|
||||
subDescription={
|
||||
<>
|
||||
This voice will be used for spoken responses. See full list
|
||||
of supported languages and voices at{" "}
|
||||
<a
|
||||
href={
|
||||
PROVIDER_VOICE_DOCS_URLS[providerType]?.url ??
|
||||
PROVIDER_DOCS_URLS[providerType]
|
||||
}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
{PROVIDER_VOICE_DOCS_URLS[providerType]?.label ?? label}
|
||||
</a>
|
||||
.
|
||||
</>
|
||||
}
|
||||
subDescription={markdown(
|
||||
`This voice will be used for spoken responses. See full list of supported languages and voices at [${
|
||||
PROVIDER_VOICE_DOCS_URLS[providerType]?.label ?? label
|
||||
}](${
|
||||
PROVIDER_VOICE_DOCS_URLS[providerType]?.url ??
|
||||
PROVIDER_DOCS_URLS[providerType]
|
||||
}).`
|
||||
)}
|
||||
nonInteractive
|
||||
>
|
||||
<InputComboBox
|
||||
|
||||
@@ -1,32 +1,25 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import { useMemo, useState, useReducer } from "react";
|
||||
import { useEffect, useMemo, useState, useReducer } from "react";
|
||||
import { InfoIcon } from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { Content } from "@opal/layouts";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgArrowRightCircle,
|
||||
SvgCheckSquare,
|
||||
SvgEdit,
|
||||
SvgGlobe,
|
||||
SvgOnyxLogo,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { SvgGlobe, SvgOnyxLogo, SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
|
||||
|
||||
const route = ADMIN_ROUTES.WEB_SEARCH;
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import {
|
||||
SEARCH_PROVIDERS_URL,
|
||||
SEARCH_PROVIDER_DETAILS,
|
||||
@@ -58,6 +51,10 @@ import {
|
||||
} from "@/app/admin/configuration/web-search/WebProviderModalReducer";
|
||||
import { connectProviderFlow } from "@/app/admin/configuration/web-search/connectProviderFlow";
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
const route = ADMIN_ROUTES.WEB_SEARCH;
|
||||
|
||||
interface WebSearchProviderView {
|
||||
id: number;
|
||||
name: string;
|
||||
@@ -76,27 +73,151 @@ interface WebContentProviderView {
|
||||
has_api_key: boolean;
|
||||
}
|
||||
|
||||
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
|
||||
isHovered: boolean;
|
||||
onMouseEnter: () => void;
|
||||
onMouseLeave: () => void;
|
||||
children: React.ReactNode;
|
||||
interface DisconnectTargetState {
|
||||
id: number;
|
||||
label: string;
|
||||
category: "search" | "content";
|
||||
providerType: string;
|
||||
}
|
||||
|
||||
function HoverIconButton({
|
||||
isHovered,
|
||||
onMouseEnter,
|
||||
onMouseLeave,
|
||||
children,
|
||||
...buttonProps
|
||||
}: HoverIconButtonProps) {
|
||||
function WebSearchDisconnectModal({
|
||||
disconnectTarget,
|
||||
searchProviders,
|
||||
contentProviders,
|
||||
replacementProviderId,
|
||||
onReplacementChange,
|
||||
onClose,
|
||||
onDisconnect,
|
||||
}: {
|
||||
disconnectTarget: DisconnectTargetState;
|
||||
searchProviders: WebSearchProviderView[];
|
||||
contentProviders: WebContentProviderView[];
|
||||
replacementProviderId: string | null;
|
||||
onReplacementChange: (id: string | null) => void;
|
||||
onClose: () => void;
|
||||
onDisconnect: () => void;
|
||||
}) {
|
||||
const isSearch = disconnectTarget.category === "search";
|
||||
|
||||
// Determine if the target is currently the active/selected provider
|
||||
const isActive = isSearch
|
||||
? searchProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
|
||||
false
|
||||
: contentProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
|
||||
false;
|
||||
|
||||
// Find other configured providers as replacements
|
||||
const replacementOptions = isSearch
|
||||
? searchProviders.filter(
|
||||
(p) => p.id !== disconnectTarget.id && p.id > 0 && p.has_api_key
|
||||
)
|
||||
: contentProviders.filter(
|
||||
(p) =>
|
||||
p.id !== disconnectTarget.id &&
|
||||
p.provider_type !== "onyx_web_crawler" &&
|
||||
p.id > 0 &&
|
||||
p.has_api_key
|
||||
);
|
||||
|
||||
const needsReplacement = isActive;
|
||||
const hasReplacements = replacementOptions.length > 0;
|
||||
|
||||
const getLabel = (p: { name: string; provider_type: string }) => {
|
||||
if (isSearch) {
|
||||
const details =
|
||||
SEARCH_PROVIDER_DETAILS[p.provider_type as WebSearchProviderType];
|
||||
return details?.label ?? p.name ?? p.provider_type;
|
||||
}
|
||||
const details = CONTENT_PROVIDER_DETAILS[p.provider_type];
|
||||
return details?.label ?? p.name ?? p.provider_type;
|
||||
};
|
||||
|
||||
const categoryLabel = isSearch ? "search engine" : "web crawler";
|
||||
const featureLabel = isSearch ? "web search" : "web crawling";
|
||||
const disableLabel = isSearch ? "Disable Web Search" : "Disable Web Crawling";
|
||||
|
||||
// Auto-select first replacement when modal opens
|
||||
useEffect(() => {
|
||||
if (needsReplacement && hasReplacements && !replacementProviderId) {
|
||||
const first = replacementOptions[0];
|
||||
if (first) onReplacementChange(String(first.id));
|
||||
}
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
{/* TODO(@raunakab): migrate to opal Button once HoverIconButtonProps typing is resolved */}
|
||||
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
|
||||
{children}
|
||||
</Button>
|
||||
</div>
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title={`Disconnect ${disconnectTarget.label}`}
|
||||
description="This will remove the stored credentials for this provider."
|
||||
onClose={onClose}
|
||||
submit={
|
||||
<OpalButton
|
||||
variant="danger"
|
||||
onClick={onDisconnect}
|
||||
disabled={
|
||||
needsReplacement && hasReplacements && !replacementProviderId
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</OpalButton>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.label}</b> is currently the active{" "}
|
||||
{categoryLabel}. Search history will be preserved.
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" secondaryBody text03>
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
value={replacementProviderId ?? undefined}
|
||||
onValueChange={(v) => onReplacementChange(v)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a replacement provider" />
|
||||
<InputSelect.Content>
|
||||
{replacementOptions.map((p) => (
|
||||
<InputSelect.Item key={p.id} value={String(p.id)}>
|
||||
{getLabel(p)}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
<InputSelect.Separator />
|
||||
<InputSelect.Item value={NO_DEFAULT_VALUE} icon={SvgSlash}>
|
||||
<span>
|
||||
<b>No Default</b>
|
||||
<span className="text-text-03"> ({disableLabel})</span>
|
||||
</span>
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.label}</b> is currently the active{" "}
|
||||
{categoryLabel}.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Connect another provider to continue using {featureLabel}.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
{isSearch ? "Web search" : "Web crawling"} will no longer be routed
|
||||
through <b>{disconnectTarget.label}</b>.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Search history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -105,6 +226,11 @@ export default function Page() {
|
||||
WebProviderModalReducer,
|
||||
initialWebProviderModalState
|
||||
);
|
||||
const [disconnectTarget, setDisconnectTarget] =
|
||||
useState<DisconnectTargetState | null>(null);
|
||||
const [replacementProviderId, setReplacementProviderId] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
const [contentModal, dispatchContentModal] = useReducer(
|
||||
WebProviderModalReducer,
|
||||
initialWebProviderModalState
|
||||
@@ -113,8 +239,6 @@ export default function Page() {
|
||||
const [contentActivationError, setContentActivationError] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
|
||||
|
||||
const {
|
||||
data: searchProvidersData,
|
||||
error: searchProvidersError,
|
||||
@@ -833,6 +957,67 @@ export default function Page() {
|
||||
});
|
||||
};
|
||||
|
||||
const handleDisconnectProvider = async () => {
|
||||
if (!disconnectTarget) return;
|
||||
const { id, category } = disconnectTarget;
|
||||
|
||||
try {
|
||||
// If a replacement was selected (not "No Default"), activate it first
|
||||
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
|
||||
const repId = Number(replacementProviderId);
|
||||
const activateEndpoint =
|
||||
category === "search"
|
||||
? `/api/admin/web-search/search-providers/${repId}/activate`
|
||||
: `/api/admin/web-search/content-providers/${repId}/activate`;
|
||||
const activateResp = await fetch(activateEndpoint, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
if (!activateResp.ok) {
|
||||
const errorBody = await activateResp.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to activate replacement provider."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
`/api/admin/web-search/${category}-providers/${id}`,
|
||||
{ method: "DELETE" }
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch((parseErr) => {
|
||||
console.error("Failed to parse disconnect error response:", parseErr);
|
||||
return {};
|
||||
});
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to disconnect provider."
|
||||
);
|
||||
}
|
||||
|
||||
toast.success(`${disconnectTarget.label} disconnected`);
|
||||
await mutateSearchProviders();
|
||||
await mutateContentProviders();
|
||||
} catch (error) {
|
||||
console.error("Failed to disconnect web search provider:", error);
|
||||
const message =
|
||||
error instanceof Error ? error.message : "Unexpected error occurred.";
|
||||
if (category === "search") {
|
||||
setActivationError(message);
|
||||
} else {
|
||||
setContentActivationError(message);
|
||||
}
|
||||
} finally {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<SettingsLayouts.Root>
|
||||
@@ -894,149 +1079,79 @@ export default function Page() {
|
||||
provider
|
||||
);
|
||||
const isActive = provider?.is_active ?? false;
|
||||
const isHighlighted = isActive;
|
||||
const providerId = provider?.id;
|
||||
const canOpenModal =
|
||||
isBuiltInSearchProviderType(providerType);
|
||||
|
||||
const buttonState = (() => {
|
||||
if (!provider || !isConfigured) {
|
||||
return {
|
||||
label: "Connect",
|
||||
disabled: false,
|
||||
icon: "arrow" as const,
|
||||
onClick: canOpenModal
|
||||
const status: "disconnected" | "connected" | "selected" =
|
||||
!isConfigured
|
||||
? "disconnected"
|
||||
: isActive
|
||||
? "selected"
|
||||
: "connected";
|
||||
|
||||
return (
|
||||
<Select
|
||||
key={`${key}-${providerType}`}
|
||||
icon={() =>
|
||||
logoSrc ? (
|
||||
<Image
|
||||
src={logoSrc}
|
||||
alt={`${label} logo`}
|
||||
width={16}
|
||||
height={16}
|
||||
/>
|
||||
) : (
|
||||
<SvgGlobe size={16} />
|
||||
)
|
||||
}
|
||||
title={label}
|
||||
description={subtitle}
|
||||
status={status}
|
||||
onConnect={
|
||||
canOpenModal
|
||||
? () => {
|
||||
openSearchModal(providerType, provider);
|
||||
setActivationError(null);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
if (isActive) {
|
||||
return {
|
||||
label: "Current Default",
|
||||
disabled: false,
|
||||
icon: "check" as const,
|
||||
onClick: providerId
|
||||
: undefined
|
||||
}
|
||||
onSelect={
|
||||
providerId
|
||||
? () => {
|
||||
void handleActivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
onDeselect={
|
||||
providerId
|
||||
? () => {
|
||||
void handleDeactivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
disabled: false,
|
||||
icon: "arrow-circle" as const,
|
||||
onClick: providerId
|
||||
? () => {
|
||||
void handleActivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
})();
|
||||
|
||||
const buttonKey = `search-${key}-${providerType}`;
|
||||
const isButtonHovered = hoveredButtonKey === buttonKey;
|
||||
const isCardClickable =
|
||||
buttonState.icon === "arrow" &&
|
||||
typeof buttonState.onClick === "function" &&
|
||||
!buttonState.disabled;
|
||||
|
||||
const handleCardClick = () => {
|
||||
if (isCardClickable) {
|
||||
buttonState.onClick?.();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`${key}-${providerType}`}
|
||||
onClick={isCardClickable ? handleCardClick : undefined}
|
||||
className={cn(
|
||||
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
|
||||
isHighlighted
|
||||
? "border-action-link-05"
|
||||
: "border-border-01",
|
||||
isCardClickable &&
|
||||
"cursor-pointer hover:bg-background-tint-01 transition-colors"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-1 items-start gap-1 px-2 py-1">
|
||||
{renderLogo({
|
||||
logoSrc,
|
||||
alt: `${label} logo`,
|
||||
size: 16,
|
||||
isHighlighted,
|
||||
})}
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{isConfigured && (
|
||||
<OpalButton
|
||||
icon={SvgEdit}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
if (!canOpenModal) return;
|
||||
: undefined
|
||||
}
|
||||
onEdit={
|
||||
isConfigured && canOpenModal
|
||||
? () => {
|
||||
openSearchModal(
|
||||
providerType as WebSearchProviderType,
|
||||
provider
|
||||
);
|
||||
}}
|
||||
aria-label={`Edit ${label}`}
|
||||
/>
|
||||
)}
|
||||
{buttonState.icon === "check" ? (
|
||||
<HoverIconButton
|
||||
isHovered={isButtonHovered}
|
||||
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
|
||||
onMouseLeave={() => setHoveredButtonKey(null)}
|
||||
action={true}
|
||||
tertiary
|
||||
disabled={buttonState.disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
>
|
||||
{buttonState.label}
|
||||
</HoverIconButton>
|
||||
) : (
|
||||
<Disabled
|
||||
disabled={
|
||||
buttonState.disabled || !buttonState.onClick
|
||||
}
|
||||
>
|
||||
<OpalButton
|
||||
prominence="tertiary"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
rightIcon={
|
||||
buttonState.icon === "arrow"
|
||||
? SvgArrowExchange
|
||||
: buttonState.icon === "arrow-circle"
|
||||
? SvgArrowRightCircle
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{buttonState.label}
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
: undefined
|
||||
}
|
||||
onDisconnect={
|
||||
isConfigured && provider && provider.id > 0
|
||||
? () =>
|
||||
setDisconnectTarget({
|
||||
id: provider.id,
|
||||
label,
|
||||
category: "search",
|
||||
providerType,
|
||||
})
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
}
|
||||
)}
|
||||
@@ -1076,161 +1191,81 @@ export default function Page() {
|
||||
const isCurrentCrawler =
|
||||
provider.provider_type === currentContentProviderType;
|
||||
|
||||
const buttonState = (() => {
|
||||
if (!isConfigured) {
|
||||
return {
|
||||
label: "Connect",
|
||||
icon: "arrow" as const,
|
||||
disabled: false,
|
||||
onClick: () => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
setContentActivationError(null);
|
||||
},
|
||||
};
|
||||
}
|
||||
const status: "disconnected" | "connected" | "selected" =
|
||||
!isConfigured
|
||||
? "disconnected"
|
||||
: isCurrentCrawler
|
||||
? "selected"
|
||||
: "connected";
|
||||
|
||||
if (isCurrentCrawler) {
|
||||
return {
|
||||
label: "Current Crawler",
|
||||
icon: "check" as const,
|
||||
disabled: false,
|
||||
onClick: () => {
|
||||
void handleDeactivateContentProvider(
|
||||
providerId,
|
||||
provider.provider_type
|
||||
);
|
||||
},
|
||||
};
|
||||
}
|
||||
const canActivate =
|
||||
providerId > 0 ||
|
||||
provider.provider_type === "onyx_web_crawler" ||
|
||||
isConfigured;
|
||||
|
||||
const canActivate =
|
||||
providerId > 0 ||
|
||||
provider.provider_type === "onyx_web_crawler" ||
|
||||
isConfigured;
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
icon: "arrow-circle" as const,
|
||||
disabled: !canActivate,
|
||||
onClick: canActivate
|
||||
? () => {
|
||||
void handleActivateContentProvider(provider);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
})();
|
||||
|
||||
const contentButtonKey = `content-${provider.provider_type}-${provider.id}`;
|
||||
const isContentButtonHovered =
|
||||
hoveredButtonKey === contentButtonKey;
|
||||
const isContentCardClickable =
|
||||
buttonState.icon === "arrow" &&
|
||||
typeof buttonState.onClick === "function" &&
|
||||
!buttonState.disabled;
|
||||
|
||||
const handleContentCardClick = () => {
|
||||
if (isContentCardClickable) {
|
||||
buttonState.onClick?.();
|
||||
}
|
||||
};
|
||||
const contentLogoSrc =
|
||||
CONTENT_PROVIDER_DETAILS[provider.provider_type]?.logoSrc;
|
||||
|
||||
return (
|
||||
<div
|
||||
<Select
|
||||
key={`${provider.provider_type}-${provider.id}`}
|
||||
onClick={
|
||||
isContentCardClickable
|
||||
? handleContentCardClick
|
||||
icon={() =>
|
||||
contentLogoSrc ? (
|
||||
<Image
|
||||
src={contentLogoSrc}
|
||||
alt={`${label} logo`}
|
||||
width={16}
|
||||
height={16}
|
||||
/>
|
||||
) : provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo size={16} />
|
||||
) : (
|
||||
<SvgGlobe size={16} />
|
||||
)
|
||||
}
|
||||
title={label}
|
||||
description={subtitle}
|
||||
status={status}
|
||||
selectedLabel="Current Crawler"
|
||||
onConnect={() => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
setContentActivationError(null);
|
||||
}}
|
||||
onSelect={
|
||||
canActivate
|
||||
? () => {
|
||||
void handleActivateContentProvider(provider);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
className={cn(
|
||||
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
|
||||
isCurrentCrawler
|
||||
? "border-action-link-05"
|
||||
: "border-border-01",
|
||||
isContentCardClickable &&
|
||||
"cursor-pointer hover:bg-background-tint-01 transition-colors"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-1 items-start gap-1 px-2 py-1">
|
||||
{renderLogo({
|
||||
logoSrc:
|
||||
CONTENT_PROVIDER_DETAILS[provider.provider_type]
|
||||
?.logoSrc,
|
||||
alt: `${label} logo`,
|
||||
fallback:
|
||||
provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo size={16} />
|
||||
) : undefined,
|
||||
size: 16,
|
||||
isHighlighted: isCurrentCrawler,
|
||||
})}
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured && (
|
||||
<OpalButton
|
||||
icon={SvgEdit}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
openContentModal(
|
||||
provider.provider_type,
|
||||
provider
|
||||
);
|
||||
}}
|
||||
aria-label={`Edit ${label}`}
|
||||
/>
|
||||
)}
|
||||
{buttonState.icon === "check" ? (
|
||||
<HoverIconButton
|
||||
isHovered={isContentButtonHovered}
|
||||
onMouseEnter={() =>
|
||||
setHoveredButtonKey(contentButtonKey)
|
||||
onDeselect={() => {
|
||||
void handleDeactivateContentProvider(
|
||||
providerId,
|
||||
provider.provider_type
|
||||
);
|
||||
}}
|
||||
onEdit={
|
||||
provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured
|
||||
? () => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
}
|
||||
onMouseLeave={() => setHoveredButtonKey(null)}
|
||||
action={true}
|
||||
tertiary
|
||||
disabled={buttonState.disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
>
|
||||
{buttonState.label}
|
||||
</HoverIconButton>
|
||||
) : (
|
||||
<Disabled
|
||||
disabled={
|
||||
buttonState.disabled || !buttonState.onClick
|
||||
}
|
||||
>
|
||||
<OpalButton
|
||||
prominence="tertiary"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
rightIcon={
|
||||
buttonState.icon === "arrow"
|
||||
? SvgArrowExchange
|
||||
: buttonState.icon === "arrow-circle"
|
||||
? SvgArrowRightCircle
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{buttonState.label}
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
: undefined
|
||||
}
|
||||
onDisconnect={
|
||||
provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured &&
|
||||
provider.id > 0
|
||||
? () =>
|
||||
setDisconnectTarget({
|
||||
id: provider.id,
|
||||
label,
|
||||
category: "content",
|
||||
providerType: provider.provider_type,
|
||||
})
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
@@ -1238,6 +1273,21 @@ export default function Page() {
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
|
||||
{disconnectTarget && (
|
||||
<WebSearchDisconnectModal
|
||||
disconnectTarget={disconnectTarget}
|
||||
searchProviders={searchProviders}
|
||||
contentProviders={combinedContentProviders}
|
||||
replacementProviderId={replacementProviderId}
|
||||
onReplacementChange={setReplacementProviderId}
|
||||
onClose={() => {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}}
|
||||
onDisconnect={() => void handleDisconnectProvider()}
|
||||
/>
|
||||
)}
|
||||
|
||||
<WebProviderSetupModal
|
||||
isOpen={selectedProviderType !== null}
|
||||
onClose={() => {
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
TableCell,
|
||||
TableHeader,
|
||||
} from "@/components/ui/table";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import { CCPairFullInfo } from "./types";
|
||||
import { IndexAttemptSnapshot } from "@/lib/types";
|
||||
@@ -153,17 +153,11 @@ export function IndexAttemptsTable({
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{indexAttempt.status === "success" && (
|
||||
<Text className="flex flex-wrap whitespace-normal">
|
||||
{"-"}
|
||||
</Text>
|
||||
)}
|
||||
{indexAttempt.status === "success" && <Text as="p">-</Text>}
|
||||
|
||||
{indexAttempt.status === "failed" &&
|
||||
indexAttempt.error_msg && (
|
||||
<Text className="flex flex-wrap whitespace-normal">
|
||||
{indexAttempt.error_msg}
|
||||
</Text>
|
||||
<Text as="p">{indexAttempt.error_msg}</Text>
|
||||
)}
|
||||
</TableCell>
|
||||
<td className="w-0 p-0">
|
||||
|
||||
@@ -11,9 +11,10 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { Button } from "@opal/components";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { Card } from "@/components/ui/card";
|
||||
import Text from "@/components/ui/text";
|
||||
import { markdown } from "@opal/utils";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import { Spinner } from "@/components/Spinner";
|
||||
import { SvgDownloadCloud } from "@opal/icons";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
@@ -75,11 +76,12 @@ function Main() {
|
||||
<>
|
||||
{isDownloading && <Spinner />}
|
||||
<div className="mb-8">
|
||||
<Text className="mb-3">
|
||||
<b>Debug Logs</b> provide detailed information about system operations
|
||||
and events. You can download logs for each category to analyze system
|
||||
behavior or troubleshoot issues.
|
||||
<Text as="p">
|
||||
{markdown(
|
||||
"**Debug Logs** provide detailed information about system operations and events. You can download logs for each category to analyze system behavior or troubleshoot issues."
|
||||
)}
|
||||
</Text>
|
||||
<Spacer rem={0.75} />
|
||||
|
||||
{categories.length > 0 && (
|
||||
<Card className="mt-4">
|
||||
|
||||
@@ -10,7 +10,9 @@ import {
|
||||
TableBody,
|
||||
TableCell,
|
||||
} from "@/components/ui/table";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import Title from "@/components/ui/title";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { DocumentSetSummary } from "@/lib/types";
|
||||
@@ -393,11 +395,12 @@ function Main() {
|
||||
|
||||
return (
|
||||
<div className="mb-8">
|
||||
<Text className="mb-3">
|
||||
<b>Document Sets</b> allow you to group logically connected documents
|
||||
into a single bundle. These can then be used as a filter when performing
|
||||
searches to control the scope of information Onyx searches over.
|
||||
<Text as="p">
|
||||
{markdown(
|
||||
"**Document Sets** allow you to group logically connected documents into a single bundle. These can then be used as a filter when performing searches to control the scope of information Onyx searches over."
|
||||
)}
|
||||
</Text>
|
||||
<Spacer rem={0.75} />
|
||||
|
||||
<div className="mb-3"></div>
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import Title from "@/components/ui/title";
|
||||
import {
|
||||
CloudEmbeddingProvider,
|
||||
@@ -99,10 +101,12 @@ export default function CloudEmbeddingPage({
|
||||
<Title className="mt-8">
|
||||
Here are some cloud-based models to choose from.
|
||||
</Title>
|
||||
<Text className="mb-4">
|
||||
These models require API keys and run in the clouds of the respective
|
||||
providers.
|
||||
<Text as="p">
|
||||
{
|
||||
"These models require API keys and run in the clouds of the respective providers."
|
||||
}
|
||||
</Text>
|
||||
<Spacer rem={1} />
|
||||
|
||||
<div className="gap-4 mt-2 pb-10 flex content-start flex-wrap">
|
||||
{providers.map((provider) => (
|
||||
@@ -156,18 +160,11 @@ export default function CloudEmbeddingPage({
|
||||
</div>
|
||||
))}
|
||||
|
||||
<Text className="mt-6">
|
||||
Alternatively, you can use a self-hosted model using the LiteLLM
|
||||
proxy. This allows you to leverage various LLM providers through a
|
||||
unified interface that you control.{" "}
|
||||
<a
|
||||
href="https://docs.litellm.ai/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-blue-500 hover:underline"
|
||||
>
|
||||
Learn more about LiteLLM
|
||||
</a>
|
||||
<Spacer rem={1.5} />
|
||||
<Text as="p">
|
||||
{markdown(
|
||||
"Alternatively, you can use a self-hosted model using the LiteLLM proxy. This allows you to leverage various LLM providers through a unified interface that you control. [Learn more about LiteLLM](https://docs.litellm.ai/)"
|
||||
)}
|
||||
</Text>
|
||||
|
||||
<div key={LITELLM_CLOUD_PROVIDER.provider_type} className="mt-4 w-full">
|
||||
@@ -214,20 +211,25 @@ export default function CloudEmbeddingPage({
|
||||
{!liteLLMProvider && (
|
||||
<CardSection className="mt-2 w-full max-w-4xl bg-background-50 border border-background-200">
|
||||
<div className="p-4">
|
||||
<Text className="text-lg font-semibold mb-2">
|
||||
<Text as="p" font="heading-h3">
|
||||
API URL Required
|
||||
</Text>
|
||||
<Text className="text-sm text-text-600 mb-4">
|
||||
Before you can add models, you need to provide an API URL
|
||||
for your LiteLLM proxy. Click the "Provide API
|
||||
URL" button above to set up your LiteLLM configuration.
|
||||
<Spacer rem={0.5} />
|
||||
<Text as="p">
|
||||
{
|
||||
'Before you can add models, you need to provide an API URL for your LiteLLM proxy. Click the "Provide API URL" button above to set up your LiteLLM configuration.'
|
||||
}
|
||||
</Text>
|
||||
<Spacer rem={1} />
|
||||
<div className="flex items-center">
|
||||
<FiInfo className="text-blue-500 mr-2" size={18} />
|
||||
<Text className="text-sm text-blue-500">
|
||||
Once configured, you'll be able to add and manage
|
||||
your LiteLLM models here.
|
||||
</Text>
|
||||
<span className="text-blue-500">
|
||||
<Text as="p">
|
||||
{
|
||||
"Once configured, you'll be able to add and manage your LiteLLM models here."
|
||||
}
|
||||
</Text>
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
@@ -281,9 +283,11 @@ export default function CloudEmbeddingPage({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Text className="mt-6">
|
||||
You can also use Azure OpenAI models for embeddings. Azure requires
|
||||
separate configuration for each model.
|
||||
<Spacer rem={1.5} />
|
||||
<Text as="p">
|
||||
{
|
||||
"You can also use Azure OpenAI models for embeddings. Azure requires separate configuration for each model."
|
||||
}
|
||||
</Text>
|
||||
|
||||
<div key={AZURE_CLOUD_PROVIDER.provider_type} className="mt-4 w-full">
|
||||
@@ -319,18 +323,22 @@ export default function CloudEmbeddingPage({
|
||||
</button>
|
||||
<div className="mt-2 w-full max-w-4xl">
|
||||
<CardSection className="p-4 border border-background-200 rounded-lg shadow-sm">
|
||||
<Text className="text-base font-medium mb-2">
|
||||
<Text as="p" font="main-ui-action">
|
||||
Configure Azure OpenAI for Embeddings
|
||||
</Text>
|
||||
<Text className="text-sm text-text-600 mb-3">
|
||||
Click "Configure Azure OpenAI" to set up Azure
|
||||
OpenAI for embeddings.
|
||||
<Spacer rem={0.5} />
|
||||
<Text as="p">
|
||||
{
|
||||
'Click "Configure Azure OpenAI" to set up Azure OpenAI for embeddings.'
|
||||
}
|
||||
</Text>
|
||||
<div className="flex items-center text-sm text-text-700">
|
||||
<Spacer rem={0.75} />
|
||||
<div className="flex items-center">
|
||||
<FiInfo className="text-neutral-400 mr-2" size={16} />
|
||||
<Text>
|
||||
You'll need: API version, base URL, API key, model
|
||||
name, and deployment name.
|
||||
<Text as="p">
|
||||
{
|
||||
"You'll need: API version, base URL, API key, model name, and deployment name."
|
||||
}
|
||||
</Text>
|
||||
</div>
|
||||
</CardSection>
|
||||
@@ -339,9 +347,10 @@ export default function CloudEmbeddingPage({
|
||||
) : (
|
||||
<>
|
||||
<div className="mb-6 w-full">
|
||||
<Text className="text-lg font-semibold mb-3">
|
||||
<Text as="p" font="heading-h3">
|
||||
Current Azure Configuration
|
||||
</Text>
|
||||
<Spacer rem={0.75} />
|
||||
|
||||
{azureProviderDetails ? (
|
||||
<CardSection className="bg-white shadow-sm border border-background-200 rounded-lg">
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"use client";
|
||||
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import Title from "@/components/ui/title";
|
||||
import { ModelSelector } from "../../../../components/embedding/ModelSelector";
|
||||
import {
|
||||
@@ -25,41 +27,28 @@ export default function OpenEmbeddingPage({
|
||||
<Title className="mt-8">
|
||||
Here are some locally-hosted models to choose from.
|
||||
</Title>
|
||||
<Text className="mb-4">
|
||||
These models can be used without any API keys, and can leverage a GPU
|
||||
for faster inference.
|
||||
<Text as="p">
|
||||
{
|
||||
"These models can be used without any API keys, and can leverage a GPU for faster inference."
|
||||
}
|
||||
</Text>
|
||||
<Spacer rem={1} />
|
||||
<ModelSelector
|
||||
modelOptions={AVAILABLE_MODELS}
|
||||
setSelectedModel={onSelectOpenSource}
|
||||
currentEmbeddingModel={selectedProvider}
|
||||
/>
|
||||
|
||||
<Text className="mt-6">
|
||||
Alternatively, (if you know what you're doing) you can specify a{" "}
|
||||
<a
|
||||
target="_blank"
|
||||
href="https://www.sbert.net/"
|
||||
className="text-link"
|
||||
rel="noreferrer"
|
||||
>
|
||||
SentenceTransformers
|
||||
</a>
|
||||
-compatible model of your choice below. The rough list of supported
|
||||
models can be found{" "}
|
||||
<a
|
||||
target="_blank"
|
||||
href="https://huggingface.co/models?library=sentence-transformers&sort=trending"
|
||||
className="text-link"
|
||||
rel="noreferrer"
|
||||
>
|
||||
here
|
||||
</a>
|
||||
.
|
||||
<br />
|
||||
<b>NOTE:</b> not all models listed will work with Onyx, since some have
|
||||
unique interfaces or special requirements. If in doubt, reach out to the
|
||||
Onyx team.
|
||||
<Spacer rem={1.5} />
|
||||
<Text as="p">
|
||||
{markdown(
|
||||
"Alternatively, (if you know what you're doing) you can specify a [SentenceTransformers](https://www.sbert.net/)-compatible model of your choice below. The rough list of supported models can be found [here](https://huggingface.co/models?library=sentence-transformers&sort=trending)."
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p">
|
||||
{markdown(
|
||||
"**NOTE:** not all models listed will work with Onyx, since some have unique interfaces or special requirements. If in doubt, reach out to the Onyx team."
|
||||
)}
|
||||
</Text>
|
||||
{!configureModel && (
|
||||
// TODO(@raunakab): migrate to opal Button once className/iconClassName is resolved
|
||||
|
||||
@@ -5,7 +5,9 @@ import { SearchAndFilterControls } from "./SearchAndFilterControls";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import Link from "next/link";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import { useConnectorIndexingStatusWithPagination } from "@/lib/hooks";
|
||||
import { useToastFromQuery } from "@/hooks/useToast";
|
||||
import { Button } from "@opal/components";
|
||||
@@ -185,13 +187,14 @@ function Main() {
|
||||
<ConnectorStaggeredSkeleton rowCount={8} standalone={true} />
|
||||
</div>
|
||||
) : !ccPairsIndexingStatuses || ccPairsIndexingStatuses.length === 0 ? (
|
||||
<Text className="mt-12">
|
||||
It looks like you don't have any connectors setup yet. Visit the{" "}
|
||||
<Link className="text-link" href="/admin/add-connector">
|
||||
Add Connector
|
||||
</Link>{" "}
|
||||
page to get started!
|
||||
</Text>
|
||||
<div>
|
||||
<Spacer rem={3} />
|
||||
<Text as="p">
|
||||
{markdown(
|
||||
"It looks like you don't have any connectors setup yet. Visit the [Add Connector](/admin/add-connector) page to get started!"
|
||||
)}
|
||||
</Text>
|
||||
</div>
|
||||
) : (
|
||||
<CCPairIndexingStatusTable
|
||||
ccPairsIndexingStatuses={ccPairsIndexingStatuses}
|
||||
|
||||
@@ -16,7 +16,8 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import Checkbox from "@/refresh-components/inputs/Checkbox";
|
||||
import { TableHeader } from "@/components/ui/table";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
|
||||
type TokenRateLimitTableArgs = {
|
||||
tokenRateLimits: TokenRateLimitDisplay[];
|
||||
@@ -68,11 +69,15 @@ export const TokenRateLimitTable = ({
|
||||
<div className="w-full">
|
||||
{!hideHeading && title && <Title>{title}</Title>}
|
||||
{!hideHeading && description && (
|
||||
<Text className="my-2">{description}</Text>
|
||||
<>
|
||||
<Spacer rem={0.5} />
|
||||
<Text as="p">{description}</Text>
|
||||
<Spacer rem={0.5} />
|
||||
</>
|
||||
)}
|
||||
<Text className={`${!hideHeading && "my-8"}`}>
|
||||
No token rate limits set!
|
||||
</Text>
|
||||
{!hideHeading && <Spacer rem={2} />}
|
||||
<Text as="p">No token rate limits set!</Text>
|
||||
{!hideHeading && <Spacer rem={2} />}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -81,7 +86,11 @@ export const TokenRateLimitTable = ({
|
||||
<div className="w-full">
|
||||
{!hideHeading && title && <Title>{title}</Title>}
|
||||
{!hideHeading && description && (
|
||||
<Text className="my-2">{description}</Text>
|
||||
<>
|
||||
<Spacer rem={0.5} />
|
||||
<Text as="p">{description}</Text>
|
||||
<Spacer rem={0.5} />
|
||||
</>
|
||||
)}
|
||||
<Table
|
||||
className={`overflow-visible ${
|
||||
@@ -188,7 +197,7 @@ export const GenericTokenRateLimitTable = ({
|
||||
}
|
||||
|
||||
if (!isLoading && error) {
|
||||
return <Text>Failed to load token rate limits</Text>;
|
||||
return <Text as="p">Failed to load token rate limits</Text>;
|
||||
}
|
||||
|
||||
let processedData = data;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import SimpleTabs from "@/refresh-components/SimpleTabs";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import { useState } from "react";
|
||||
import {
|
||||
insertGlobalTokenRateLimit,
|
||||
@@ -104,14 +104,14 @@ function Main() {
|
||||
|
||||
return (
|
||||
<Section alignItems="stretch" justifyContent="start" height="auto">
|
||||
<Text>
|
||||
<Text as="p">
|
||||
Token rate limits enable you control how many tokens can be spent in a
|
||||
given time period. With token rate limits, you can:
|
||||
</Text>
|
||||
|
||||
<ul className="list-disc ml-4">
|
||||
<li>
|
||||
<Text>
|
||||
<Text as="p">
|
||||
Set a global rate limit to control your team's overall token
|
||||
spend.
|
||||
</Text>
|
||||
@@ -119,13 +119,13 @@ function Main() {
|
||||
{isPaidEnterpriseFeaturesEnabled && (
|
||||
<>
|
||||
<li>
|
||||
<Text>
|
||||
<Text as="p">
|
||||
Set rate limits for users to ensure that no single user can
|
||||
spend too many tokens.
|
||||
</Text>
|
||||
</li>
|
||||
<li>
|
||||
<Text>
|
||||
<Text as="p">
|
||||
Set rate limits for user groups to control token spend for your
|
||||
teams.
|
||||
</Text>
|
||||
@@ -133,7 +133,7 @@ function Main() {
|
||||
</>
|
||||
)}
|
||||
<li>
|
||||
<Text>Enable and disable rate limits on the fly.</Text>
|
||||
<Text as="p">Enable and disable rate limits on the fly.</Text>
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
|
||||
@@ -3,7 +3,9 @@ import React, { useState } from "react";
|
||||
import { forgotPassword } from "./utils";
|
||||
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
||||
import Title from "@/components/ui/title";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import Link from "next/link";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
@@ -73,12 +75,11 @@ const ForgotPasswordPage: React.FC = () => {
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
<Spacer rem={1} />
|
||||
<div className="flex">
|
||||
<Text className="mt-4 mx-auto">
|
||||
<Link href="/auth/login" className="text-link font-medium">
|
||||
Back to Login
|
||||
</Link>
|
||||
</Text>
|
||||
<div className="mx-auto">
|
||||
<Text as="p">{markdown("[Back to Login](/auth/login)")}</Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</AuthFlowContainer>
|
||||
|
||||
@@ -3,7 +3,9 @@ import React, { useState, useEffect } from "react";
|
||||
import { resetPassword } from "../forgot-password/utils";
|
||||
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
||||
import Title from "@/components/ui/title";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import Link from "next/link";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
@@ -109,12 +111,11 @@ const ResetPasswordPage: React.FC = () => {
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
<Spacer rem={1} />
|
||||
<div className="flex">
|
||||
<Text className="mt-4 mx-auto">
|
||||
<Link href="/auth/login" className="text-link font-medium">
|
||||
Back to Login
|
||||
</Link>
|
||||
</Text>
|
||||
<div className="mx-auto">
|
||||
<Text as="p">{markdown("[Back to Login](/auth/login)")}</Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</AuthFlowContainer>
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import { RequestNewVerificationEmail } from "../waiting-on-verification/RequestNewVerificationEmail";
|
||||
import { User } from "@/lib/types";
|
||||
import Logo from "@/refresh-components/Logo";
|
||||
@@ -65,17 +66,22 @@ export default function Verify({ user }: VerifyProps) {
|
||||
<div className="min-h-screen flex flex-col items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
|
||||
<Logo folded size={64} className="mx-auto w-fit animate-pulse" />
|
||||
{!error ? (
|
||||
<Text className="mt-2">Verifying your email...</Text>
|
||||
<>
|
||||
<Spacer rem={0.5} />
|
||||
<Text as="p">Verifying your email...</Text>
|
||||
</>
|
||||
) : (
|
||||
<div>
|
||||
<Text className="mt-2">{error}</Text>
|
||||
<Spacer rem={0.5} />
|
||||
<Text as="p">{error}</Text>
|
||||
|
||||
{user && (
|
||||
<div className="text-center">
|
||||
<RequestNewVerificationEmail email={user.email}>
|
||||
<Text className="mt-2 text-link">
|
||||
{/* TODO(@raunakab): migrate to @opal/components Text */}
|
||||
<p className="text-sm mt-2 text-link">
|
||||
Get new verification email
|
||||
</Text>
|
||||
</p>
|
||||
</RequestNewVerificationEmail>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -4,11 +4,11 @@ import {
|
||||
getCurrentUserSS,
|
||||
} from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
|
||||
import { User } from "@/lib/types";
|
||||
import Text from "@/components/ui/text";
|
||||
import { RequestNewVerificationEmail } from "./RequestNewVerificationEmail";
|
||||
import Logo from "@/refresh-components/Logo";
|
||||
import { Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
|
||||
export default async function Page() {
|
||||
// catch cases where the backend is completely unreachable here
|
||||
@@ -35,22 +35,21 @@ export default async function Page() {
|
||||
|
||||
return (
|
||||
<main>
|
||||
<div className="min-h-screen flex flex-col items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
|
||||
<div className="min-h-screen flex flex-col items-center justify-center py-12 px-4 sm:px-6 lg:px-8 gap-4">
|
||||
<Logo folded size={64} className="mx-auto w-fit" />
|
||||
<div className="flex">
|
||||
<Text className="text-center font-medium text-lg mt-6 w-108">
|
||||
Hey <i>{currentUser.email}</i> - it looks like you haven't
|
||||
verified your email yet.
|
||||
<br />
|
||||
Check your inbox for an email from us to get started!
|
||||
<br />
|
||||
<br />
|
||||
If you don't see anything, click{" "}
|
||||
<RequestNewVerificationEmail email={currentUser.email}>
|
||||
here
|
||||
</RequestNewVerificationEmail>{" "}
|
||||
to request a new email.
|
||||
<div className="flex flex-col gap-2">
|
||||
<Text as="span">
|
||||
{markdown(
|
||||
`Hey, *${currentUser.email}*, it looks like you haven't verified your email yet.\nCheck your inbox for an email from us to get started!`
|
||||
)}
|
||||
</Text>
|
||||
<div className="flex flex-row items-center gap-1">
|
||||
<Text as="span">If you don't see anything, click</Text>
|
||||
<RequestNewVerificationEmail email={currentUser.email}>
|
||||
<Text as="span">here</Text>
|
||||
</RequestNewVerificationEmail>
|
||||
<Text as="span">to request a new email.</Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
|
||||
@@ -19,6 +19,10 @@
|
||||
background-color: var(--background-neutral-00);
|
||||
border: 1px solid var(--status-error-05);
|
||||
}
|
||||
.input-error:focus:not(:active),
|
||||
.input-error:focus-within:not(:active) {
|
||||
box-shadow: inset 0px 0px 0px 2px var(--background-tint-04);
|
||||
}
|
||||
|
||||
.input-disabled {
|
||||
background-color: var(--background-neutral-03);
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
import { Label, SubLabel } from "@/components/Field";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
import { Button } from "@opal/components";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Text from "@/components/ui/text";
|
||||
import { useContext, useState } from "react";
|
||||
import InputTextArea from "@/refresh-components/inputs/InputTextArea";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
@@ -54,17 +54,17 @@ export function CustomAnalyticsUpdateForm() {
|
||||
>
|
||||
<div className="mb-4">
|
||||
<Label>Script</Label>
|
||||
<Text className="mb-3">
|
||||
<Text as="p">
|
||||
Specify the Javascript that should run on page load in order to
|
||||
initialize your custom tracking/analytics.
|
||||
</Text>
|
||||
<Text className="mb-2">
|
||||
Do not include the{" "}
|
||||
<span className="font-mono"><script></script></span>{" "}
|
||||
tags. If you upload a script below but you are not recieving any
|
||||
events in your analytics platform, try removing all extra whitespace
|
||||
before each line of JavaScript.
|
||||
<Spacer rem={0.75} />
|
||||
<Text as="p">
|
||||
{markdown(
|
||||
"Do not include the `<script></script>` tags. If you upload a script below but you are not receiving any events in your analytics platform, try removing all extra whitespace before each line of JavaScript."
|
||||
)}
|
||||
</Text>
|
||||
<Spacer rem={0.5} />
|
||||
<InputTextArea
|
||||
value={newCustomAnalyticsScript}
|
||||
onChange={(event) =>
|
||||
|
||||
@@ -2,7 +2,8 @@ import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { CUSTOM_ANALYTICS_ENABLED } from "@/lib/constants";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import { CustomAnalyticsUpdateForm } from "./CustomAnalyticsUpdateForm";
|
||||
|
||||
const route = ADMIN_ROUTES.CUSTOM_ANALYTICS;
|
||||
@@ -24,11 +25,12 @@ function Main() {
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Text className="mb-8">
|
||||
This allows you to bring your own analytics tool to Onyx! Copy the Web
|
||||
snippet from your analytics provider into the box below, and we'll
|
||||
start sending usage events.
|
||||
<Text as="p">
|
||||
{
|
||||
"This allows you to bring your own analytics tool to Onyx! Copy the Web snippet from your analytics provider into the box below, and we'll start sending usage events."
|
||||
}
|
||||
</Text>
|
||||
<Spacer rem={2} />
|
||||
|
||||
<CustomAnalyticsUpdateForm />
|
||||
</div>
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"use client";
|
||||
import { use } from "react";
|
||||
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import Title from "@/components/ui/title";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import { ChatSessionSnapshot, MessageSnapshot } from "../../usage/types";
|
||||
import { FiBook } from "react-icons/fi";
|
||||
import { timestampToReadableDate } from "@/lib/dateUtils";
|
||||
@@ -21,13 +22,13 @@ function MessageDisplay({ message }: { message: MessageSnapshot }) {
|
||||
<p className="text-xs font-bold mb-1">
|
||||
{message.message_type === "user" ? "User" : "AI"}
|
||||
</p>
|
||||
<Text>{message.message}</Text>
|
||||
<Text as="p">{message.message}</Text>
|
||||
{message.documents.length > 0 && (
|
||||
<div className="flex flex-col gap-y-2 mt-2">
|
||||
<p className="font-bold text-xs">Reference Documents</p>
|
||||
{message.documents.slice(0, 5).map((document) => {
|
||||
return (
|
||||
<Text className="flex" key={document.document_id}>
|
||||
<div className="text-sm flex" key={document.document_id}>
|
||||
<FiBook
|
||||
className={
|
||||
"my-auto mr-1" + (document.link ? " text-link" : " ")
|
||||
@@ -45,7 +46,7 @@ function MessageDisplay({ message }: { message: MessageSnapshot }) {
|
||||
) : (
|
||||
document.semantic_identifier
|
||||
)}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
@@ -53,7 +54,7 @@ function MessageDisplay({ message }: { message: MessageSnapshot }) {
|
||||
{message.feedback_type && (
|
||||
<div className="mt-2">
|
||||
<p className="font-bold text-xs">Feedback</p>
|
||||
{message.feedback_text && <Text>{message.feedback_text}</Text>}
|
||||
{message.feedback_text && <Text as="p">{message.feedback_text}</Text>}
|
||||
<div className="mt-1">
|
||||
<FeedbackBadge feedback={message.feedback_type} />
|
||||
</div>
|
||||
@@ -95,14 +96,19 @@ export default function QueryPage(props: { params: Promise<{ id: string }> }) {
|
||||
<CardSection className="mt-4">
|
||||
<Title>Chat Session Details</Title>
|
||||
|
||||
<Text className="flex flex-wrap whitespace-normal mt-1 text-sm">
|
||||
{chatSessionSnapshot.assistant_name}
|
||||
</Text>
|
||||
<Text className="flex flex-wrap whitespace-normal mt-1 text-xs">
|
||||
{chatSessionSnapshot.user_email &&
|
||||
`${chatSessionSnapshot.user_email}, `}
|
||||
{timestampToReadableDate(chatSessionSnapshot.time_created)},{" "}
|
||||
{chatSessionSnapshot.flow_type}
|
||||
<Spacer rem={0.25} />
|
||||
{chatSessionSnapshot.assistant_name && (
|
||||
<Text as="p">{chatSessionSnapshot.assistant_name}</Text>
|
||||
)}
|
||||
<Spacer rem={0.25} />
|
||||
<Text as="p">
|
||||
{`${
|
||||
chatSessionSnapshot.user_email
|
||||
? `${chatSessionSnapshot.user_email}, `
|
||||
: ""
|
||||
}${timestampToReadableDate(chatSessionSnapshot.time_created)}, ${
|
||||
chatSessionSnapshot.flow_type
|
||||
}`}
|
||||
</Text>
|
||||
|
||||
<Separator />
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { getDatesList, useQueryAnalytics } from "../lib";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import Title from "@/components/ui/title";
|
||||
|
||||
import { DateRangePickerValue } from "@/components/dateRangeSelectors/AdminDateRangeSelector";
|
||||
@@ -68,7 +68,7 @@ export function FeedbackChart({
|
||||
return (
|
||||
<CardSection className="mt-8">
|
||||
<Title>Feedback</Title>
|
||||
<Text>Thumbs Up / Thumbs Down over time</Text>
|
||||
<Text as="p">Thumbs Up / Thumbs Down over time</Text>
|
||||
{chart}
|
||||
</CardSection>
|
||||
);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { getDatesList, useOnyxBotAnalytics } from "../lib";
|
||||
import { DateRangePickerValue } from "@/components/dateRangeSelectors/AdminDateRangeSelector";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import Title from "@/components/ui/title";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import { AreaChartDisplay } from "@/components/ui/areaChart";
|
||||
@@ -69,7 +69,7 @@ export function OnyxBotChart({
|
||||
return (
|
||||
<CardSection className="mt-8">
|
||||
<Title>Slack Channel</Title>
|
||||
<Text>Total Queries vs Auto Resolved</Text>
|
||||
<Text as="p">Total Queries vs Auto Resolved</Text>
|
||||
{chart}
|
||||
</CardSection>
|
||||
);
|
||||
|
||||
@@ -6,7 +6,7 @@ import {
|
||||
usePersonaUniqueUsers,
|
||||
} from "../lib";
|
||||
import { DateRangePickerValue } from "@/components/dateRangeSelectors/AdminDateRangeSelector";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import Title from "@/components/ui/title";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import { AreaChartDisplay } from "@/components/ui/areaChart";
|
||||
@@ -180,7 +180,9 @@ export function PersonaMessagesChart({
|
||||
<CardSection className="mt-8">
|
||||
<Title>Agent Analytics</Title>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text>Messages and unique users per day for the selected agent</Text>
|
||||
<Text as="p">
|
||||
Messages and unique users per day for the selected agent
|
||||
</Text>
|
||||
<div className="flex items-center gap-4">
|
||||
<Select
|
||||
value={selectedPersonaId?.toString() ?? ""}
|
||||
|
||||
@@ -5,7 +5,7 @@ import { getDatesList, useQueryAnalytics, useUserAnalytics } from "../lib";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { AreaChartDisplay } from "@/components/ui/areaChart";
|
||||
import Title from "@/components/ui/title";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
|
||||
export function QueryPerformanceChart({
|
||||
@@ -98,7 +98,7 @@ export function QueryPerformanceChart({
|
||||
return (
|
||||
<CardSection className="mt-8">
|
||||
<Title>Usage</Title>
|
||||
<Text>Usage over time</Text>
|
||||
<Text as="p">Usage over time</Text>
|
||||
{chart}
|
||||
</CardSection>
|
||||
);
|
||||
|
||||
@@ -12,8 +12,9 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import Title from "@/components/ui/title";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
@@ -98,9 +99,8 @@ function GenerateReportInput({
|
||||
return (
|
||||
<div className="mb-8">
|
||||
<Title className="mb-2">Generate Usage Reports</Title>
|
||||
<Text className="mb-8">
|
||||
Generate usage statistics for users in the workspace.
|
||||
</Text>
|
||||
<Text as="p">Generate usage statistics for users in the workspace.</Text>
|
||||
<Spacer rem={2} />
|
||||
<div className="grid gap-2 mb-3">
|
||||
<Popover>
|
||||
<Popover.Trigger asChild>
|
||||
@@ -412,9 +412,9 @@ export default function UsageReports() {
|
||||
isWaitingForReport={isWaitingForReport}
|
||||
/>
|
||||
{timeoutMessage && (
|
||||
<div className="mb-4 p-4 bg-amber-50 dark:bg-amber-950/20 border border-amber-200 dark:border-amber-800 rounded-regular">
|
||||
<div className="mb-4 p-4 bg-status-warning-00 border border-status-warning-02 rounded-regular">
|
||||
<div className="flex items-start gap-2">
|
||||
<div className="text-amber-600 dark:text-amber-500 mt-0.5">
|
||||
<div className="text-status-warning-05 mt-0.5">
|
||||
<svg
|
||||
className="w-5 h-5"
|
||||
fill="none"
|
||||
@@ -430,12 +430,15 @@ export default function UsageReports() {
|
||||
</svg>
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<Text className="text-amber-800 dark:text-amber-200 font-medium mb-1">
|
||||
Report Generation In Progress
|
||||
</Text>
|
||||
<Text className="text-amber-700 dark:text-amber-300 text-sm">
|
||||
{timeoutMessage}
|
||||
</Text>
|
||||
<div className="text-status-warning-05">
|
||||
<Text as="p" font="main-ui-action">
|
||||
Report Generation In Progress
|
||||
</Text>
|
||||
</div>
|
||||
<Spacer rem={0.25} />
|
||||
<div className="text-status-warning-05">
|
||||
<Text as="p">{timeoutMessage}</Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -25,7 +25,9 @@ import { deleteStandardAnswer } from "./lib";
|
||||
import { FilterDropdown } from "@/components/search/filtering/FilterDropdown";
|
||||
import { FiTag } from "react-icons/fi";
|
||||
import { PageSelector } from "@/components/PageSelector";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import { TableHeader } from "@/components/ui/table";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { SvgEdit, SvgTrash } from "@opal/icons";
|
||||
@@ -316,19 +318,17 @@ const StandardAnswersTable = ({
|
||||
<div>
|
||||
{paginatedStandardAnswers.length === 0 && (
|
||||
<div className="flex justify-center">
|
||||
<Text>No matching standard answers found...</Text>
|
||||
<Text as="p">No matching standard answers found...</Text>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{paginatedStandardAnswers.length > 0 && (
|
||||
<>
|
||||
<div className="mt-4">
|
||||
<Text>
|
||||
Ensure that you have added the category to the relevant{" "}
|
||||
<a className="text-link" href="/admin/bots">
|
||||
Slack Bot
|
||||
</a>
|
||||
.
|
||||
<Text as="p">
|
||||
{markdown(
|
||||
"Ensure that you have added the category to the relevant [Slack Bot](/admin/bots)."
|
||||
)}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="mt-4 flex justify-center">
|
||||
@@ -389,14 +389,17 @@ function Main() {
|
||||
|
||||
return (
|
||||
<div className="mb-8">
|
||||
<Text className="mb-2">
|
||||
Manage the standard answers for pre-defined questions.
|
||||
<br />
|
||||
Note: Currently, only questions asked from Slack can receive standard
|
||||
answers.
|
||||
<Text as="p">
|
||||
{markdown(
|
||||
"Manage the standard answers for pre-defined questions.\nNote: Currently, only questions asked from Slack can receive standard answers."
|
||||
)}
|
||||
</Text>
|
||||
<Spacer rem={0.5} />
|
||||
{standardAnswers.length == 0 && (
|
||||
<Text className="mb-2">Add your first standard answer below!</Text>
|
||||
<>
|
||||
<Text as="p">Add your first standard answer below!</Text>
|
||||
<Spacer rem={0.5} />
|
||||
</>
|
||||
)}
|
||||
<div className="mb-2"></div>
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import { toast } from "@/hooks/useToast";
|
||||
import CreateCredential from "./actions/CreateCredential";
|
||||
import { CCPairFullInfo } from "@/app/admin/connector/[ccPairId]/types";
|
||||
import ModifyCredential from "./actions/ModifyCredential";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import {
|
||||
buildCCPairInfoUrl,
|
||||
buildSimilarCredentialInfoURL,
|
||||
@@ -185,7 +185,7 @@ export default function CredentialSection({
|
||||
<div className="flex-grow flex flex-col justify-center">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<Text className="font-medium">
|
||||
<Text as="p">
|
||||
{ccPair.credential.name ||
|
||||
`Credential #${ccPair.credential.id}`}
|
||||
</Text>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
|
||||
import { FaNewspaper, FaTrash } from "react-icons/fa";
|
||||
import { TextFormField, TypedFileUploadFormField } from "@/components/Field";
|
||||
@@ -51,7 +51,7 @@ export default function EditCredential({
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-y-6">
|
||||
<Text>
|
||||
<Text as="p">
|
||||
Ensure that you update to a credential with the proper permissions!
|
||||
</Text>
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ import { Formik, Form } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { TextFormField, BooleanFormField } from "@/components/Field";
|
||||
import { Dispatch, SetStateAction } from "react";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm";
|
||||
|
||||
@@ -59,10 +60,12 @@ export function CustomEmbeddingModelForm({
|
||||
>
|
||||
{({ isSubmitting, submitForm, errors }) => (
|
||||
<Form>
|
||||
<Text className="text-xl text-text-900 font-bold mb-4">
|
||||
Specify details for your {getFormattedProviderName(embeddingType)}{" "}
|
||||
Provider's model
|
||||
<Text as="p" font="heading-h3">
|
||||
{`Specify details for your ${getFormattedProviderName(
|
||||
embeddingType
|
||||
)} Provider's model`}
|
||||
</Text>
|
||||
<Spacer rem={1} />
|
||||
<TextFormField
|
||||
name="model_name"
|
||||
label="Model Name:"
|
||||
|
||||
@@ -13,7 +13,8 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Text } from "@opal/components";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
import { FiLink, FiMaximize2, FiTrash } from "react-icons/fi";
|
||||
@@ -68,15 +69,21 @@ export function FailedReIndexAttempts({
|
||||
/>
|
||||
)}
|
||||
|
||||
<Text className="text-status-error-05 font-semibold mb-2">
|
||||
Failed Re-indexing Attempts
|
||||
</Text>
|
||||
<Text className="text-status-error-05 mb-4">
|
||||
The table below shows only the failed re-indexing attempts for existing
|
||||
connectors. These failures require immediate attention. Once all
|
||||
connectors have been re-indexed successfully, the new model will be used
|
||||
for all search queries.
|
||||
</Text>
|
||||
<div className="text-status-error-05">
|
||||
<Text as="p" font="main-ui-action">
|
||||
Failed Re-indexing Attempts
|
||||
</Text>
|
||||
</div>
|
||||
<Spacer rem={0.5} />
|
||||
<div className="text-status-error-05">
|
||||
<Text as="p">
|
||||
The table below shows only the failed re-indexing attempts for
|
||||
existing connectors. These failures require immediate attention. Once
|
||||
all connectors have been re-indexed successfully, the new model will
|
||||
be used for all search queries.
|
||||
</Text>
|
||||
</div>
|
||||
<Spacer rem={1} />
|
||||
|
||||
<div>
|
||||
<Table>
|
||||
@@ -114,7 +121,7 @@ export function FailedReIndexAttempts({
|
||||
|
||||
<TableCell>
|
||||
<div>
|
||||
<Text className="flex flex-wrap whitespace-normal">
|
||||
<Text as="p">
|
||||
{reindexingProgress.error_msg || "-"}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
* and support various display sizes.
|
||||
*/
|
||||
import React from "react";
|
||||
import { SvgBifrost } from "@opal/icons";
|
||||
import { render } from "@tests/setup/test-utils";
|
||||
import { GithubIcon, GitbookIcon, ConfluenceIcon } from "./icons";
|
||||
|
||||
@@ -51,4 +52,15 @@ describe("Logo Icons", () => {
|
||||
render(<GithubIcon size={100} className="custom-class" />);
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
test("renders the Bifrost icon with theme-aware colors", () => {
|
||||
const { container } = render(
|
||||
<SvgBifrost size={32} className="custom text-red-500 dark:text-black" />
|
||||
);
|
||||
const icon = container.querySelector("svg");
|
||||
|
||||
expect(icon).toBeInTheDocument();
|
||||
expect(icon).toHaveClass("custom", "text-[#33C19E]", "dark:text-white");
|
||||
expect(icon).not.toHaveClass("text-red-500", "dark:text-black");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export default function Text({
|
||||
children,
|
||||
className,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
className?: string;
|
||||
}) {
|
||||
return <p className={cn("text-sm", className)}>{children}</p>;
|
||||
}
|
||||
@@ -13,6 +13,7 @@ export enum LLMProviderName {
|
||||
VERTEX_AI = "vertex_ai",
|
||||
BEDROCK = "bedrock",
|
||||
LITELLM_PROXY = "litellm_proxy",
|
||||
BIFROST = "bifrost",
|
||||
CUSTOM = "custom",
|
||||
}
|
||||
|
||||
@@ -165,6 +166,21 @@ export interface LiteLLMProxyModelResponse {
|
||||
model_name: string;
|
||||
}
|
||||
|
||||
export interface BifrostFetchParams {
|
||||
api_base?: string;
|
||||
api_key?: string;
|
||||
provider_name?: string;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface BifrostModelResponse {
|
||||
name: string;
|
||||
display_name: string;
|
||||
max_input_tokens: number | null;
|
||||
supports_image_input: boolean;
|
||||
supports_reasoning: boolean;
|
||||
}
|
||||
|
||||
export interface VertexAIFetchParams {
|
||||
model_configurations?: ModelConfiguration[];
|
||||
}
|
||||
@@ -182,5 +198,6 @@ export type FetchModelsParams =
|
||||
| OllamaFetchParams
|
||||
| OpenRouterFetchParams
|
||||
| LiteLLMProxyFetchParams
|
||||
| BifrostFetchParams
|
||||
| VertexAIFetchParams
|
||||
| LMStudioFetchParams;
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import type { RichStr } from "@opal/types";
|
||||
import { resolveStr } from "@opal/components/text/InlineMarkdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgXOctagon, SvgAlertCircle } from "@opal/icons";
|
||||
import { useField, useFormikContext } from "formik";
|
||||
@@ -12,9 +14,9 @@ interface OrientationLayoutProps {
|
||||
disabled?: boolean;
|
||||
nonInteractive?: boolean;
|
||||
children?: React.ReactNode;
|
||||
title: string;
|
||||
title: string | RichStr;
|
||||
titleSuffix?: string;
|
||||
description?: string;
|
||||
description?: string | RichStr;
|
||||
optional?: boolean;
|
||||
sizePreset?: "main-content" | "main-ui";
|
||||
}
|
||||
@@ -42,7 +44,7 @@ interface OrientationLayoutProps {
|
||||
* ```
|
||||
*/
|
||||
export interface VerticalLayoutProps extends OrientationLayoutProps {
|
||||
subDescription?: React.ReactNode;
|
||||
subDescription?: string | RichStr;
|
||||
}
|
||||
function VerticalInputLayout({
|
||||
name,
|
||||
@@ -70,7 +72,7 @@ function VerticalInputLayout({
|
||||
{name && <ErrorLayout name={name} />}
|
||||
{subDescription && (
|
||||
<Text secondaryBody text03>
|
||||
{subDescription}
|
||||
{resolveStr(subDescription)}
|
||||
</Text>
|
||||
)}
|
||||
</Section>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user