Compare commits

..

6 Commits

Author SHA1 Message Date
Nik
c222348ff3 fix(review): address Greptile comments on PR1
- Add owner to bare TODO comment
- Add explicit type annotation to _completion_done
- Restore placement field assertions weakened by Emitter refactor
2026-03-31 12:49:29 -07:00
Nik
4ae5e96fab feat(chat): add multi-model parallel streaming (N=2-3 LLMs side-by-side)
Adds support for running 2-3 LLMs in parallel within a single chat turn,
with responses streamed interleaved to the frontend via the merged queue
infrastructure introduced in the preceding PR.

Backend changes
- process_message.py: restore llm_overrides param on build_chat_turn and
  _stream_chat_turn; restore is_multi branching for LLM setup, context
  window sizing, and message ID reservation; add _build_model_display_name
  and handle_multi_model_stream (public multi-model entrypoint)
- db/chat.py: add reserve_multi_model_message_ids (reserves N assistant
  message placeholders sharing the same parent), set_preferred_response
  (marks one response as the user's preferred), and extend
  translate_db_message_to_chat_message_detail with preferred_response_id
  and model_display_name fields
- chat_backend.py: route requests with llm_overrides >1 through
  handle_multi_model_stream; reject non-streaming multi-model requests with
  OnyxError; add /set-preferred-response endpoint

Tests
- test_multi_model_streaming.py: unit tests for _run_models drain loop
  (arrival-order yield, error isolation, cancellation), handle_multi_model_stream
  validation guards, and N=1 backwards-compatibility
2026-03-31 12:49:29 -07:00
Nik
3365a369e2 fix(review): address Greptile comments
- Add owner to bare TODO comment
- Restore placement field assertions weakened by Emitter refactor
2026-03-31 12:49:09 -07:00
Nik
470bda3fb5 refactor(chat): elegance pass on PR1 changed files
process_message.py:
- Fix `skip_clarification` field in ChatTurnSetup: inline comment inside
  the type annotation → separate `#` comment on the line above the field
- Flatten `model_tools` via list comprehension instead of manual extend loop
- `forced_tool_id` membership test: list → set comprehension (O(1) lookup)
- Trim `_run_model` inner-function docstring — private closure doesn't need
  10-line Args block
- Remove redundant inline param comments from `_stream_chat_turn` and
  `handle_stream_message_objects` where the docstring Args section already
  documents them
- Strip duplicate Args/Returns from `handle_stream_message_objects` docstring
  — it delegates entirely to `_stream_chat_turn`

emitter.py:
- Widen `merged_queue` annotation to `Queue[Any]`: Queue is invariant so
  `Queue[tuple[int, Packet]]` can't be passed a `Queue[tuple[int, Packet |
  Exception | object]]`; the emitter is a write-only producer and doesn't
  care what else lives on the queue
2026-03-31 12:16:38 -07:00
Nik
13f511e209 refactor(emitter): clean up string annotation and use model_copy
- Fix `"Queue"` forward-reference annotation → `Queue[tuple[int, Packet]]`
  (Queue is already imported, the string was unnecessary)
- Replace manual Placement field copy with `base.model_copy(update={...})`
- Remove redundant `key` variable (was just `self._model_idx`)
- Tighten docstring
2026-03-31 11:44:28 -07:00
Nik
c5e8ba1eab refactor(chat): replace bus-polling emitter with merged-queue streaming; fix 429 hang
Switch Emitter from a per-model event bus + polling thread to a single
bounded queue shared across all models.  Each emit() call puts directly onto
the queue; the drain loop in _run_models yields packets in arrival order.

Key changes
- emitter.py: remove Bus, get_default_emitter(); add Emitter(merged_queue, model_idx)
- chat_state.py: remove run_chat_loop_with_state_containers (113-line bus-poll loop)
- process_message.py: add ChatTurnSetup dataclass and build_chat_turn(); rewrite
  _stream_chat_turn + _run_models around the merged queue; single-model (N=1)
  path is fully backwards-compatible
- placement.py, override_models.py: add docstrings; LLMOverride gains display_name
- research_agent.py, custom_tool.py: update Emitter call sites
- test_emitter.py: new unit tests for queue routing, model_index tagging, placement

Frontend 429 fix
- lib.tsx: parse response body for human-readable detail on non-2xx responses
  instead of "HTTP error! status: 429"
- useChatController.ts: surface stack.error after the FIFO drain loop exits so
  the catch block replaces the thinking placeholder with an error message
2026-03-30 22:18:48 -07:00
128 changed files with 4651 additions and 7870 deletions

View File

@@ -704,9 +704,6 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
@@ -789,9 +786,6 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest

View File

@@ -41,7 +41,7 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
uses: helm/chart-testing-action@2e2940618cb426dce2999631d543b53cdcfc8527
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
with:
uv_version: "0.9.9"

View File

@@ -13,7 +13,6 @@ from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
@@ -30,10 +29,9 @@ from shared_configs.configs import TENANT_ID_PREFIX
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
_MAX_TENANTS_PER_RUN = 5
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
@shared_task(
@@ -93,7 +91,8 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
if batch_size < tenants_to_provision:
task_logger.info(
f"Capping batch to {batch_size} (need {tenants_to_provision}, will catch up next cycle)"
f"Capping batch to {batch_size} "
f"(need {tenants_to_provision}, will catch up next cycle)"
)
provisioned = 0
@@ -104,14 +103,12 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
provisioned += 1
except Exception:
task_logger.exception(
f"Failed to provision tenant {i + 1}/{batch_size}, continuing with remaining tenants"
f"Failed to provision tenant {i + 1}/{batch_size}, "
"continuing with remaining tenants"
)
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
# Migrate any pool tenants that were provisioned before a new migration was deployed
_migrate_stale_pool_tenants()
except Exception:
task_logger.exception("Error in check_available_tenants task")
@@ -124,46 +121,6 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
)
def _migrate_stale_pool_tenants() -> None:
    """
    Bring every tenant listed in ``AvailableTenant`` up to the current schema.

    For each pool tenant this calls ``run_alembic_migrations`` and then reads
    the resulting version with ``get_current_alembic_version``. When the stored
    ``alembic_version`` on the pool row differs from that result, the row is
    updated and committed. The migration call is expected to be a cheap no-op
    for tenants that are already current, which keeps pre-provisioned tenants
    from hitting schema mismatches at signup.

    A failure for one tenant is logged and does not stop the remaining tenants
    from being processed.
    """
    # Collect the ids up front so no session stays open while migrations run.
    with get_session_with_shared_schema() as db_session:
        tenant_ids = [
            pool_row.tenant_id for pool_row in db_session.query(AvailableTenant).all()
        ]

    if not tenant_ids:
        return

    task_logger.info(
        f"Checking {len(tenant_ids)} pool tenant(s) for pending migrations"
    )

    for tenant_id in tenant_ids:
        try:
            run_alembic_migrations(tenant_id)
            new_version = get_current_alembic_version(tenant_id)

            with get_session_with_shared_schema() as db_session:
                lookup = db_session.query(AvailableTenant).filter_by(
                    tenant_id=tenant_id
                )
                tenant = lookup.first()

                # Nothing to record: the row is gone or already at this version.
                if not tenant or tenant.alembic_version == new_version:
                    continue

                task_logger.info(
                    f"Migrated pool tenant {tenant_id}: {tenant.alembic_version} -> {new_version}"
                )
                tenant.alembic_version = new_version
                db_session.commit()
        except Exception:
            # Best effort: log with traceback and move on to the next tenant.
            task_logger.exception(
                f"Failed to migrate pool tenant {tenant_id}, skipping"
            )
def pre_provision_tenant() -> bool:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.

View File

@@ -99,26 +99,6 @@ async def get_or_provision_tenant(
tenant_id = await get_available_tenant()
if tenant_id:
# Run migrations to ensure the pre-provisioned tenant schema is current.
# Pool tenants may have been created before a new migration was deployed.
# Capture as a non-optional local so mypy can type the lambda correctly.
_tenant_id: str = tenant_id
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(
None, lambda: run_alembic_migrations(_tenant_id)
)
except Exception:
# The tenant was already dequeued from the pool — roll it back so
# it doesn't end up orphaned (schema exists, but not assigned to anyone).
logger.exception(
f"Migration failed for pre-provisioned tenant {_tenant_id}; rolling back"
)
try:
await rollback_tenant_provisioning(_tenant_id)
except Exception:
logger.exception(f"Failed to rollback orphaned tenant {_tenant_id}")
raise
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")

View File

@@ -100,7 +100,6 @@ def get_model_app() -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -20,7 +20,6 @@ from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx import __version__
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
@@ -66,7 +65,6 @@ if SENTRY_DSN:
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:
@@ -517,8 +515,7 @@ def reset_tenant_id(
def wait_for_vespa_or_shutdown(
sender: Any, # noqa: ARG001
**kwargs: Any, # noqa: ARG001
sender: Any, **kwargs: Any # noqa: ARG001
) -> None: # noqa: ARG001
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""

View File

@@ -9,7 +9,6 @@ from celery import Celery
from celery import shared_task
from celery import Task
from onyx import __version__
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
@@ -138,7 +137,6 @@ def _docfetching_task(
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -319,11 +319,6 @@ def monitor_indexing_attempt_progress(
)
current_db_time = get_db_current_time(db_session)
total_batches: int | str = (
coordination_status.total_batches
if coordination_status.total_batches is not None
else "?"
)
if coordination_status.found:
task_logger.info(
f"Indexing attempt progress: "
@@ -331,7 +326,7 @@ def monitor_indexing_attempt_progress(
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id} "
f"completed_batches={coordination_status.completed_batches} "
f"total_batches={total_batches} "
f"total_batches={coordination_status.total_batches or '?'} "
f"total_docs={coordination_status.total_docs} "
f"total_failures={coordination_status.total_failures}"
f"elapsed={(current_db_time - attempt.time_created).seconds}"

View File

@@ -1,19 +1,8 @@
import threading
import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.tools.models import ToolCallInfo
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background
# Type alias for search doc deduplication key
# Simple key: just document_id (str)
@@ -159,114 +148,3 @@ class ChatStateContainer:
"""Thread-safe getter for emitted citations (returns a copy)."""
with self._lock:
return self._emitted_citations.copy()
def run_chat_loop_with_state_containers(
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
) -> Generator[Packet, None]:
"""
Run ``chat_loop_func`` in a background thread and stream the packets it emits.

``chat_loop_func`` is started through ``run_in_background`` and receives
``emitter`` and ``state_container`` as its two positional arguments. This
generator reads packets from ``emitter.bus`` with a 0.3s timeout per read and
yields them to the caller.

How the stream ends:
- An ``OverallStop`` packet is yielded and then the loop exits.
- A ``PacketException`` packet is not yielded; the exception it carries is raised.
- When ``is_connected()`` returns False, a synthetic ``OverallStop`` packet with
``stop_reason="user_cancelled"`` is yielded at ``last_turn_index + 1`` and the
loop exits. This is checked when a read times out and also at least every
0.3s while packets are flowing.

Cleanup (``finally``): the background thread is waited on only while
``is_connected()`` is still True, and ``completion_callback(state_container)``
is called; an exception raised by the callback is emitted on ``emitter`` as a
``PacketException`` packet rather than raised.
NOTE(review): indentation is not preserved in this view, so the exact nesting
of the cleanup steps and of the periodic cancellation check should be confirmed
against the original file.

Args:
chat_loop_func: Function to run in the background; called as
``chat_loop_func(emitter, state_container)``.
completion_callback: Called with ``state_container`` during cleanup.
is_connected: Callable that returns False when the stop signal is set.
emitter: Emitter whose ``bus`` queue is read for packets.
state_container: ChatStateContainer instance for accumulating state.

Yields:
Packets read from ``emitter.bus``, plus the synthetic ``OverallStop`` packet
described above on cancellation.

Raises:
The exception wrapped by any ``PacketException`` packet read from the bus.

Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
)
for packet in packets:
# Process packets
pass
"""
def run_with_exception_capture() -> None:
try:
chat_loop_func(emitter, state_container)
except Exception as e:
# If execution fails, hand the exception to the consumer loop as an
# error packet so it is raised from the generator below
emitter.emit(
Packet(
placement=Placement(turn_index=0),
obj=PacketException(type="error", exception=e),
)
)
# Run the function in a background thread
thread = run_in_background(run_with_exception_capture)
pkt: Packet | None = None
last_turn_index = 0 # Highest turn_index seen so far; the cancel stop packet uses this + 1
last_cancel_check = time.monotonic()
cancel_check_interval = 0.3 # Check for cancellation every 300ms
try:
while True:
# Read from the queue with a 300ms timeout: this avoids busy-waiting
# while still letting the stop signal be checked regularly when idle
try:
pkt = emitter.bus.get(timeout=0.3)
except Empty:
if not is_connected():
# Stop signal detected while idle: emit a synthetic stop and finish
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = time.monotonic()
continue
if pkt is not None:
# Track the highest turn_index for the stop packet
if pkt.placement and pkt.placement.turn_index > last_turn_index:
last_turn_index = pkt.placement.turn_index
if isinstance(pkt.obj, OverallStop):
yield pkt
break
elif isinstance(pkt.obj, PacketException):
# Surface the exception carried by the packet to the caller
raise pkt.obj.exception
else:
yield pkt
# Check for cancellation periodically even when packets are flowing
# This ensures stop signal is checked during active streaming
current_time = time.monotonic()
if current_time - last_cancel_check >= cancel_check_interval:
if not is_connected():
# Stop signal detected during streaming
yield Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
)
break
last_cancel_check = current_time
finally:
# Wait for the thread only while the client is still connected, so a
# disconnected client exits quickly instead of blocking on the thread.
if is_connected():
wait_on_background(thread)
try:
completion_callback(state_container)
except Exception as e:
# Report a failing completion callback as an error packet on the emitter
emitter.emit(
Packet(
placement=Placement(turn_index=last_turn_index + 1),
obj=PacketException(type="error", exception=e),
)
)

View File

@@ -1,19 +1,45 @@
import logging
import queue
from queue import Queue
from typing import Any
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
logger = logging.getLogger(__name__)
class Emitter:
"""Use this inside tools to emit arbitrary UI progress."""
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
def __init__(self, bus: Queue):
self.bus = bus
Tags every packet with ``model_index`` and places it on ``merged_queue``
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
Args:
merged_queue: Shared queue owned by ``_run_models``.
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
"""
def __init__(
self,
merged_queue: Queue[Any],
model_idx: int = 0,
) -> None:
self._model_idx = model_idx
self._merged_queue = merged_queue
def emit(self, packet: Packet) -> None:
self.bus.put(packet) # Thread-safe
def get_default_emitter() -> Emitter:
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
return emitter
base = packet.placement or Placement(turn_index=0)
tagged = Packet(
placement=base.model_copy(update={"model_index": self._model_idx}),
obj=packet.obj,
)
try:
self._merged_queue.put((self._model_idx, tagged), timeout=3.0)
except queue.Full:
# Drain loop is gone (e.g. GeneratorExit on disconnect); discard packet.
logger.warning(
"Emitter model_idx=%d: queue full after 3s timeout, dropping packet %s",
self._model_idx,
type(packet.obj).__name__,
)

File diff suppressed because it is too large Load Diff

View File

@@ -212,7 +212,6 @@ class DocumentSource(str, Enum):
PRODUCTBOARD = "productboard"
FILE = "file"
CODA = "coda"
CANVAS = "canvas"
NOTION = "notion"
ZULIP = "zulip"
LINEAR = "linear"
@@ -673,7 +672,6 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
DocumentSource.SLAB: "slab data",
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
DocumentSource.FILE: "files",
DocumentSource.CANVAS: "canvas lms - courses, pages, assignments, and announcements",
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
project management, and collaboration tools into a single, customizable platform",

View File

@@ -1,32 +0,0 @@
"""
Permissioning / AccessControl logic for Canvas courses.
CE stub — returns None (no permissions). The EE implementation is loaded
at runtime via ``fetch_versioned_implementation``.
"""
from collections.abc import Callable
from typing import cast
from onyx.access.models import ExternalAccess
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
def get_course_permissions(
    canvas_client: CanvasApiClient,
    course_id: int,
) -> ExternalAccess | None:
    """Resolve the access information for one Canvas course.

    When ``global_version.is_ee_version()`` is false this returns ``None``.
    Otherwise the function named ``get_course_permissions`` is looked up in
    ``onyx.external_permissions.canvas.access`` through
    ``fetch_versioned_implementation`` and called with the same arguments.

    Args:
        canvas_client: Client passed through to the resolved implementation.
        course_id: Canvas course id passed through to the resolved implementation.

    Returns:
        Whatever the resolved implementation returns, or ``None`` when not
        running the EE version.
    """
    if global_version.is_ee_version():
        versioned_impl = fetch_versioned_implementation(
            "onyx.external_permissions.canvas.access",
            "get_course_permissions",
        )
        # The lookup result is untyped; pin it to the expected signature.
        resolve_permissions = cast(
            Callable[[CanvasApiClient, int], ExternalAccess | None],
            versioned_impl,
        )
        return resolve_permissions(canvas_client, course_id)

    return None

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import logging
import re
from collections.abc import Iterator
from typing import Any
from urllib.parse import urlparse
@@ -191,22 +190,3 @@ class CanvasApiClient:
if clean_endpoint:
final_url += "/" + clean_endpoint
return final_url
def paginate(
self,
endpoint: str,
params: dict[str, Any] | None = None,
) -> Iterator[list[Any]]:
"""Yield each page of results, following Link-header pagination.
Makes the first request with endpoint + params, then follows
next_url from Link headers for subsequent pages.
"""
response, next_url = self.get(endpoint, params=params)
while True:
if not response:
break
yield response
if not next_url:
break
response, next_url = self.get(full_url=next_url)

View File

@@ -1,82 +1,17 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Literal
from typing import NoReturn
from typing import TypeAlias
from pydantic import BaseModel
from retry import retry
from typing_extensions import override
from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.access import get_course_permissions
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
    """Translate a Canvas API ``OnyxError`` into a connector exception.

    This function always raises. The exception type is chosen from
    ``e.status_code``:

    - 401 -> ``CredentialExpiredError``
    - 403 -> ``InsufficientPermissionsError``
    - 429 -> ``ConnectorValidationError`` (rate limit)
    - 500 and above -> ``UnexpectedValidationError``
    - anything else -> ``ConnectorValidationError``
    """
    status = e.status_code

    if status == 401:
        raise CredentialExpiredError(
            "Canvas API token is invalid or expired (HTTP 401)."
        )

    if status == 403:
        raise InsufficientPermissionsError(
            "Canvas API token does not have sufficient permissions (HTTP 403)."
        )

    if status == 429:
        raise ConnectorValidationError(
            "Canvas rate-limit exceeded (HTTP 429). Please try again later."
        )

    if status >= 500:
        raise UnexpectedValidationError(
            f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
        )

    # Every remaining status is reported as a validation problem.
    raise ConnectorValidationError(
        f"Canvas API error (status={e.status_code}): {e}"
    )
class CanvasCourse(BaseModel):
id: int
name: str | None = None
course_code: str | None = None
created_at: str | None = None
workflow_state: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any]) -> "CanvasCourse":
return cls(
id=payload["id"],
name=payload.get("name"),
course_code=payload.get("course_code"),
created_at=payload.get("created_at"),
workflow_state=payload.get("workflow_state"),
)
name: str
course_code: str
created_at: str
workflow_state: str
class CanvasPage(BaseModel):
@@ -84,22 +19,10 @@ class CanvasPage(BaseModel):
url: str
title: str
body: str | None = None
created_at: str | None = None
updated_at: str | None = None
created_at: str
updated_at: str
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasPage":
return cls(
page_id=payload["page_id"],
url=payload["url"],
title=payload["title"],
body=payload.get("body"),
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
course_id=course_id,
)
class CanvasAssignment(BaseModel):
id: int
@@ -107,23 +30,10 @@ class CanvasAssignment(BaseModel):
description: str | None = None
html_url: str
course_id: int
created_at: str | None = None
updated_at: str | None = None
created_at: str
updated_at: str
due_at: str | None = None
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAssignment":
return cls(
id=payload["id"],
name=payload["name"],
description=payload.get("description"),
html_url=payload["html_url"],
course_id=course_id,
created_at=payload.get("created_at"),
updated_at=payload.get("updated_at"),
due_at=payload.get("due_at"),
)
class CanvasAnnouncement(BaseModel):
id: int
@@ -133,17 +43,6 @@ class CanvasAnnouncement(BaseModel):
posted_at: str | None = None
course_id: int
@classmethod
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAnnouncement":
return cls(
id=payload["id"],
title=payload["title"],
message=payload.get("message"),
html_url=payload["html_url"],
posted_at=payload.get("posted_at"),
course_id=course_id,
)
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
@@ -173,286 +72,3 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
self.current_course_index += 1
self.stage = "pages"
self.next_url = None
class CanvasConnector(
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
SlimConnectorWithPermSync,
):
def __init__(
self,
canvas_base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.canvas_base_url = canvas_base_url.rstrip("/").removesuffix("/api/v1")
self.batch_size = batch_size
self._canvas_client: CanvasApiClient | None = None
self._course_permissions_cache: dict[int, ExternalAccess | None] = {}
@property
def canvas_client(self) -> CanvasApiClient:
if self._canvas_client is None:
raise ConnectorMissingCredentialError("Canvas")
return self._canvas_client
def _get_course_permissions(self, course_id: int) -> ExternalAccess | None:
"""Get course permissions with caching."""
if course_id not in self._course_permissions_cache:
self._course_permissions_cache[course_id] = get_course_permissions(
canvas_client=self.canvas_client,
course_id=course_id,
)
return self._course_permissions_cache[course_id]
@retry(tries=3, delay=1, backoff=2)
def _list_courses(self) -> list[CanvasCourse]:
"""Fetch all courses accessible to the authenticated user."""
logger.debug("Fetching Canvas courses")
courses: list[CanvasCourse] = []
for page in self.canvas_client.paginate(
"courses", params={"per_page": "100", "state[]": "available"}
):
courses.extend(CanvasCourse.from_api(c) for c in page)
return courses
@retry(tries=3, delay=1, backoff=2)
def _list_pages(self, course_id: int) -> list[CanvasPage]:
"""Fetch all pages for a given course."""
logger.debug(f"Fetching pages for course {course_id}")
pages: list[CanvasPage] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/pages",
params={"per_page": "100", "include[]": "body", "published": "true"},
):
pages.extend(CanvasPage.from_api(p, course_id=course_id) for p in page)
return pages
@retry(tries=3, delay=1, backoff=2)
def _list_assignments(self, course_id: int) -> list[CanvasAssignment]:
"""Fetch all assignments for a given course."""
logger.debug(f"Fetching assignments for course {course_id}")
assignments: list[CanvasAssignment] = []
for page in self.canvas_client.paginate(
f"courses/{course_id}/assignments",
params={"per_page": "100", "published": "true"},
):
assignments.extend(
CanvasAssignment.from_api(a, course_id=course_id) for a in page
)
return assignments
@retry(tries=3, delay=1, backoff=2)
def _list_announcements(self, course_id: int) -> list[CanvasAnnouncement]:
"""Fetch all announcements for a given course."""
logger.debug(f"Fetching announcements for course {course_id}")
announcements: list[CanvasAnnouncement] = []
for page in self.canvas_client.paginate(
"announcements",
params={
"per_page": "100",
"context_codes[]": f"course_{course_id}",
"active_only": "true",
},
):
announcements.extend(
CanvasAnnouncement.from_api(a, course_id=course_id) for a in page
)
return announcements
def _build_document(
self,
doc_id: str,
link: str,
text: str,
semantic_identifier: str,
doc_updated_at: datetime | None,
course_id: int,
doc_type: str,
) -> Document:
"""Build a Document with standard Canvas fields."""
return Document(
id=doc_id,
sections=cast(
list[TextSection | ImageSection],
[TextSection(link=link, text=text)],
),
source=DocumentSource.CANVAS,
semantic_identifier=semantic_identifier,
doc_updated_at=doc_updated_at,
metadata={"course_id": str(course_id), "type": doc_type},
)
def _convert_page_to_document(self, page: CanvasPage) -> Document:
"""Convert a Canvas page to a Document."""
link = f"{self.canvas_base_url}/courses/{page.course_id}/pages/{page.url}"
text_parts = [page.title]
body_text = parse_html_page_basic(page.body) if page.body else ""
if body_text:
text_parts.append(body_text)
doc_updated_at = (
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
timezone.utc
)
if page.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
link=link,
text="\n\n".join(text_parts),
semantic_identifier=page.title or f"Page {page.page_id}",
doc_updated_at=doc_updated_at,
course_id=page.course_id,
doc_type="page",
)
return document
def _convert_assignment_to_document(self, assignment: CanvasAssignment) -> Document:
"""Convert a Canvas assignment to a Document."""
text_parts = [assignment.name]
desc_text = (
parse_html_page_basic(assignment.description)
if assignment.description
else ""
)
if desc_text:
text_parts.append(desc_text)
if assignment.due_at:
due_dt = datetime.fromisoformat(
assignment.due_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
doc_updated_at = (
datetime.fromisoformat(
assignment.updated_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if assignment.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-assignment-{assignment.course_id}-{assignment.id}",
link=assignment.html_url,
text="\n\n".join(text_parts),
semantic_identifier=assignment.name or f"Assignment {assignment.id}",
doc_updated_at=doc_updated_at,
course_id=assignment.course_id,
doc_type="assignment",
)
return document
def _convert_announcement_to_document(
self, announcement: CanvasAnnouncement
) -> Document:
"""Convert a Canvas announcement to a Document."""
text_parts = [announcement.title]
msg_text = (
parse_html_page_basic(announcement.message) if announcement.message else ""
)
if msg_text:
text_parts.append(msg_text)
doc_updated_at = (
datetime.fromisoformat(
announcement.posted_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if announcement.posted_at
else None
)
document = self._build_document(
doc_id=f"canvas-announcement-{announcement.course_id}-{announcement.id}",
link=announcement.html_url,
text="\n\n".join(text_parts),
semantic_identifier=announcement.title or f"Announcement {announcement.id}",
doc_updated_at=doc_updated_at,
course_id=announcement.course_id,
doc_type="announcement",
)
return document
@override
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load and validate Canvas credentials."""
access_token = credentials.get("canvas_access_token")
if not access_token:
raise ConnectorMissingCredentialError("Canvas")
try:
client = CanvasApiClient(
bearer_token=access_token,
canvas_base_url=self.canvas_base_url,
)
client.get("courses", params={"per_page": "1"})
except ValueError as e:
raise ConnectorValidationError(f"Invalid Canvas base URL: {e}")
except OnyxError as e:
_handle_canvas_api_error(e)
self._canvas_client = client
return None
@override
def validate_connector_settings(self) -> None:
    """Validate Canvas connector settings by testing API access.

    Issues a one-item ``courses`` listing through ``self.canvas_client`` and
    logs success.

    Raises:
        ConnectorMissingCredentialError: Re-raised unchanged when it occurs
            inside the probe.
        UnexpectedValidationError: For any other non-``OnyxError`` exception;
            the original exception is attached as ``__cause__``.
    """
    try:
        self.canvas_client.get("courses", params={"per_page": "1"})
        logger.info("Canvas connector settings validated successfully")
    except OnyxError as e:
        # Translation of API errors is delegated to the shared handler.
        _handle_canvas_api_error(e)
    except ConnectorMissingCredentialError:
        # Must not be swallowed by the generic handler below.
        raise
    except Exception as exc:
        # Chain explicitly so the traceback shows the underlying cause.
        raise UnexpectedValidationError(
            f"Unexpected error during Canvas settings validation: {exc}"
        ) from exc
@override
def load_from_checkpoint(
    self,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
    checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
    """Placeholder: not implemented yet, always raises ``NotImplementedError``."""
    # TODO(benwu408): implemented in PR3 (checkpoint)
    raise NotImplementedError
@override
def load_from_checkpoint_with_perm_sync(
    self,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
    checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
    """Placeholder: not implemented yet, always raises ``NotImplementedError``."""
    # TODO(benwu408): implemented in PR3 (checkpoint)
    raise NotImplementedError
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
    """Placeholder: not implemented yet, always raises ``NotImplementedError``."""
    # TODO(benwu408): implemented in PR3 (checkpoint)
    raise NotImplementedError
@override
def validate_checkpoint_json(
    self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
    """Placeholder: not implemented yet, always raises ``NotImplementedError``."""
    # TODO(benwu408): implemented in PR3 (checkpoint)
    raise NotImplementedError
@override
def retrieve_all_slim_docs_perm_sync(
    self,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
    """Placeholder: not implemented yet, always raises ``NotImplementedError``."""
    # TODO(benwu408): implemented in PR4 (perm sync)
    raise NotImplementedError

View File

@@ -72,10 +72,6 @@ CONNECTOR_CLASS_MAP = {
module_path="onyx.connectors.coda.connector",
class_name="CodaConnector",
),
DocumentSource.CANVAS: ConnectorMapping(
module_path="onyx.connectors.canvas.connector",
class_name="CanvasConnector",
),
DocumentSource.NOTION: ConnectorMapping(
module_path="onyx.connectors.notion.connector",
class_name="NotionConnector",

View File

@@ -617,6 +617,92 @@ def reserve_message_id(
return empty_message
def reserve_multi_model_message_ids(
    db_session: Session,
    chat_session_id: UUID,
    parent_message_id: int,
    model_display_names: list[str],
) -> list[ChatMessage]:
    """Reserve N assistant message placeholders for multi-model parallel streaming.

    All messages share the same parent (the user message). The parent's
    latest_child_message_id points to the LAST reserved message so that the
    default history-chain walker picks it up.

    Args:
        db_session: Session used to add, flush and commit the placeholders.
        chat_session_id: Chat session the placeholders belong to.
        parent_message_id: ID of the message all placeholders are children of.
        model_display_names: One display name per placeholder, in order.

    Returns:
        The reserved messages in the same order as ``model_display_names``.
        An empty input returns an empty list without touching the session.
    """
    # Nothing to reserve: return before ``reserved[-1]`` below can raise
    # IndexError on an empty list.
    if not model_display_names:
        return []

    reserved: list[ChatMessage] = []
    for display_name in model_display_names:
        msg = ChatMessage(
            chat_session_id=chat_session_id,
            parent_message_id=parent_message_id,
            latest_child_message_id=None,
            message="Response was terminated prior to completion, try regenerating.",
            token_count=15,  # placeholder; updated on completion by llm_loop_completion_handle
            message_type=MessageType.ASSISTANT,
            model_display_name=display_name,
        )
        db_session.add(msg)
        reserved.append(msg)

    # Flush to assign IDs without committing yet
    db_session.flush()

    # Point parent's latest_child to the last reserved message
    parent = (
        db_session.query(ChatMessage)
        .filter(ChatMessage.id == parent_message_id)
        .first()
    )
    if parent:
        parent.latest_child_message_id = reserved[-1].id

    db_session.commit()
    return reserved
def set_preferred_response(
    db_session: Session,
    user_message_id: int,
    preferred_assistant_message_id: int,
) -> None:
    """Record which assistant reply the user prefers for a multi-model turn.

    Besides storing the preference, the user message's
    ``latest_child_message_id`` is moved to the chosen reply so that it
    becomes the active branch for later messages. The change is committed.

    Args:
        db_session: Active database session.
        user_message_id: Primary key of the ``USER``-type ``ChatMessage``
            whose preferred response is being set.
        preferred_assistant_message_id: Primary key of the reply to prefer;
            it must be a direct child of ``user_message_id``.

    Raises:
        ValueError: If either message cannot be found, if
            ``user_message_id`` is not a USER message, or if the reply's
            parent is not ``user_message_id``.
    """
    parent_message = db_session.get(ChatMessage, user_message_id)
    if parent_message is None:
        raise ValueError(f"User message {user_message_id} not found")

    is_user_message = parent_message.message_type == MessageType.USER
    if not is_user_message:
        raise ValueError(f"Message {user_message_id} is not a user message")

    chosen_reply = db_session.get(ChatMessage, preferred_assistant_message_id)
    if chosen_reply is None:
        raise ValueError(
            f"Assistant message {preferred_assistant_message_id} not found"
        )

    if chosen_reply.parent_message_id != user_message_id:
        raise ValueError(
            f"Assistant message {preferred_assistant_message_id} is not a child "
            f"of user message {user_message_id}"
        )

    # Store the preference and make the chosen reply the active branch.
    parent_message.preferred_response_id = preferred_assistant_message_id
    parent_message.latest_child_message_id = preferred_assistant_message_id
    db_session.commit()
def create_new_chat_message(
chat_session_id: UUID,
parent_message: ChatMessage,
@@ -839,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
error=chat_message.error,
current_feedback=current_feedback,
processing_duration_seconds=chat_message.processing_duration_seconds,
preferred_response_id=chat_message.preferred_response_id,
model_display_name=chat_message.model_display_name,
)
return chat_msg_detail

View File

@@ -932,7 +932,7 @@ class OpenSearchIndexClient(OpenSearchClient):
def search_for_document_ids(
self,
body: dict[str, Any],
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
) -> list[str]:
"""Searches the index and returns only document chunk IDs.

View File

@@ -60,7 +60,8 @@ class OpenSearchSearchType(str, Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
RANDOM = "random"
DOC_ID_RETRIEVAL = "doc_id_retrieval"
ID_RETRIEVAL = "id_retrieval"
DOCUMENT_IDS = "document_ids"
UNKNOWN = "unknown"

View File

@@ -928,7 +928,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
search_hits = self._client.search(
body=query_body,
search_pipeline_id=None,
search_type=OpenSearchSearchType.DOC_ID_RETRIEVAL,
search_type=OpenSearchSearchType.ID_RETRIEVAL,
)
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(

View File

@@ -8,6 +8,24 @@ from pydantic import BaseModel
class LLMOverride(BaseModel):
"""Per-request LLM settings that override persona defaults.
All fields are optional — only the fields that differ from the persona's
configured LLM need to be supplied. Used both over the wire (API requests)
and for multi-model comparison, where one override is supplied per model.
Attributes:
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
When ``None``, the persona's default provider is used.
model_version: Specific model version string (e.g. ``"gpt-4o"``).
When ``None``, the persona's default model is used.
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
persona's default temperature is used.
display_name: Human-readable label shown in the UI for this model,
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
when not set.
"""
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None

View File

@@ -439,7 +439,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.1,
release=__version__,
)
logger.info("Sentry initialized")
else:

View File

@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import CreateChatSessionID
from onyx.chat.process_message import gather_stream_full
from onyx.chat.process_message import handle_multi_model_stream
from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.prompt_utils import get_default_base_system_prompt
from onyx.chat.stop_signal_checker import set_fence
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import set_preferred_response
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.session_loading import (
@@ -570,6 +575,46 @@ def handle_send_chat_message(
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
chat_message_req.origin = MessageOrigin.API
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
is_multi_model = (
chat_message_req.llm_overrides is not None
and len(chat_message_req.llm_overrides) > 1
)
if is_multi_model and chat_message_req.stream:
# Narrowed here; is_multi_model already checked llm_overrides is not None
llm_overrides = chat_message_req.llm_overrides or []
def multi_model_stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as db_session:
for obj in handle_multi_model_stream(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
llm_overrides=llm_overrides,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
mcp_headers=chat_message_req.mcp_headers,
):
yield get_json_line(obj.model_dump())
except Exception as e:
logger.exception("Error in multi-model streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(
multi_model_stream_generator(), media_type="text/event-stream"
)
if is_multi_model and not chat_message_req.stream:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
)
# Non-streaming path: consume all packets and return complete response
if not chat_message_req.stream:
with get_session_with_current_tenant() as db_session:
@@ -660,6 +705,30 @@ def set_message_as_latest(
)
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
    request_body: SetPreferredResponseRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Set the preferred assistant response for a multi-model turn.

    Args:
        request_body: Carries ``user_message_id`` and ``preferred_response_id``.
        user: The current user, or ``None``; its ID (or ``None``) is used for
            the ownership lookup.
        db_session: Database session for the lookup and the update.

    Raises:
        OnyxError: With ``INVALID_INPUT`` when the lookup or the update raises
            ``ValueError``; the original error is attached as ``__cause__``.
    """
    try:
        # Ownership check: get_chat_message raises ValueError if the message
        # doesn't belong to this user, preventing cross-user mutation.
        get_chat_message(
            chat_message_id=request_body.user_message_id,
            user_id=user.id if user else None,
            db_session=db_session,
        )
        set_preferred_response(
            db_session=db_session,
            user_message_id=request_body.user_message_id,
            preferred_assistant_message_id=request_body.preferred_response_id,
        )
    except ValueError as e:
        # Chain explicitly so the traceback shows the underlying cause.
        raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e)) from e
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,

View File

@@ -2,11 +2,25 @@ from pydantic import BaseModel
class Placement(BaseModel):
    """Coordinates that identify where a streaming packet belongs in the UI.

    The frontend uses these fields to route each packet to the correct turn,
    tool tab, agent sub-turn, and (in multi-model mode) response column.

    Attributes:
        turn_index: Monotonically increasing index of the iterative reasoning block
            (e.g. tool call round) within this chat message. Lower values happened first.
        tab_index: Disambiguates parallel tool calls within the same turn so each
            tool's output can be displayed in its own tab.
        sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
            top-level packets; an integer for tool-within-tool output.
        model_index: Which model this packet belongs to. ``0`` for single-model
            responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
            for pre-LLM setup packets (e.g. message ID info) that are yielded
            before any Emitter runs.
    """

    # Which iterative block in the UI this packet is part of; blocks are
    # ordered and smaller indices happened first.
    turn_index: int
    # For parallel tool calls, to preserve order of execution.
    tab_index: int = 0
    # Used for tools/agents that call other tools. Nested agents are not
    # currently supported but can be added later.
    sub_turn_index: int | None = None
    # For multi-model streaming: identifies which model (0, 1, 2) this packet
    # belongs to.
    model_index: int | None = None

View File

@@ -1,3 +1,4 @@
import queue
import time
from collections.abc import Callable
from typing import Any
@@ -708,7 +709,6 @@ def run_research_agent_calls(
if __name__ == "__main__":
from queue import Queue
from uuid import uuid4
from onyx.chat.chat_state import ChatStateContainer
@@ -744,8 +744,8 @@ if __name__ == "__main__":
if user is None:
raise ValueError("No users found in database. Please create a user first.")
bus: Queue[Packet] = Queue()
emitter = Emitter(bus)
emitter_queue: queue.Queue = queue.Queue()
emitter = Emitter(merged_queue=emitter_queue)
state_container = ChatStateContainer()
tool_dict = construct_tools(
@@ -792,4 +792,4 @@ if __name__ == "__main__":
print(result.intermediate_report)
print("=" * 80)
print(f"Citations: {result.citation_mapping}")
print(f"Total packets emitted: {bus.qsize()}")
print(f"Total packets emitted: {emitter_queue.qsize()}")

View File

@@ -1,5 +1,6 @@
import csv
import json
import queue
import uuid
from io import BytesIO
from io import StringIO
@@ -11,7 +12,6 @@ import requests
from requests import JSONDecodeError
from onyx.chat.emitter import Emitter
from onyx.chat.emitter import get_default_emitter
from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
# Use default emitter if none provided
# Use a discard emitter if none provided (packets go nowhere)
if emitter is None:
emitter = get_default_emitter()
emitter = Emitter(merged_queue=queue.Queue())
return [
CustomTool(
@@ -367,7 +367,7 @@ if __name__ == "__main__":
tools = build_custom_tools_from_openapi_schema_and_headers(
tool_id=0, # dummy tool id
openapi_schema=openapi_schema,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
dynamic_schema_info=None,
)

View File

@@ -5,7 +5,6 @@ import asyncio
import json
import logging
import sys
import time
from dataclasses import asdict
from dataclasses import dataclass
from pathlib import Path
@@ -28,9 +27,6 @@ INTERNAL_SEARCH_TOOL_NAME = "internal_search"
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
MAX_REQUEST_ATTEMPTS = 5
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
QUESTION_TIMEOUT_SECONDS = 300
QUESTION_RETRY_PAUSE_SECONDS = 30
MAX_QUESTION_ATTEMPTS = 3
@dataclass(frozen=True)
@@ -113,27 +109,6 @@ def normalize_api_base(api_base: str) -> str:
return f"{normalized}/api"
def load_completed_question_ids(output_file: Path) -> set[str]:
    """Collect the question IDs already recorded in a JSON-lines output file.

    Each non-blank line is parsed as JSON. Lines that are not valid JSON, are
    not JSON objects, or lack a non-empty string ``"question_id"`` are skipped.

    Args:
        output_file: Path to the JSON-lines answers file; it may not exist.

    Returns:
        The set of non-empty string ``question_id`` values found, or an empty
        set when the file does not exist.
    """
    if not output_file.exists():
        return set()

    completed_ids: set[str] = set()
    with output_file.open("r", encoding="utf-8") as file:
        for line in file:
            stripped = line.strip()
            if not stripped:
                continue

            try:
                record = json.loads(stripped)
            except json.JSONDecodeError:
                continue

            # Valid JSON that is not an object (list, number, string, null)
            # has no ``.get``; skip it like any other malformed line.
            if not isinstance(record, dict):
                continue

            question_id = record.get("question_id")
            if isinstance(question_id, str) and question_id:
                completed_ids.add(question_id)

    return completed_ids
def load_questions(questions_file: Path) -> list[QuestionRecord]:
if not questions_file.exists():
raise FileNotFoundError(f"Questions file not found: {questions_file}")
@@ -373,7 +348,6 @@ async def generate_answers(
api_base: str,
api_key: str,
parallelism: int,
skipped: int,
) -> None:
if parallelism < 1:
raise ValueError("`--parallelism` must be at least 1.")
@@ -408,178 +382,58 @@ async def generate_answers(
write_lock = asyncio.Lock()
completed = 0
successful = 0
stuck_count = 0
failed_questions: list[FailedQuestionRecord] = []
remaining_count = len(questions)
overall_total = remaining_count + skipped
question_durations: list[float] = []
run_start_time = time.monotonic()
def print_progress() -> None:
avg_time = (
sum(question_durations) / len(question_durations)
if question_durations
else 0.0
)
elapsed = time.monotonic() - run_start_time
eta = avg_time * (remaining_count - completed) / max(parallelism, 1)
done = skipped + completed
bar_width = 30
filled = (
int(bar_width * done / overall_total)
if overall_total
else bar_width
)
bar = "" * filled + "" * (bar_width - filled)
pct = (done / overall_total * 100) if overall_total else 100.0
parts = (
f"\r{bar} {pct:5.1f}% "
f"[{done}/{overall_total}] "
f"avg {avg_time:.1f}s/q "
f"elapsed {elapsed:.0f}s "
f"ETA {eta:.0f}s "
f"(ok:{successful} fail:{len(failed_questions)}"
)
if stuck_count:
parts += f" stuck:{stuck_count}"
if skipped:
parts += f" skip:{skipped}"
parts += ")"
sys.stderr.write(parts)
sys.stderr.flush()
print_progress()
total = len(questions)
async def process_question(question_record: QuestionRecord) -> None:
nonlocal completed
nonlocal successful
nonlocal stuck_count
last_error: Exception | None = None
for attempt in range(1, MAX_QUESTION_ATTEMPTS + 1):
q_start = time.monotonic()
try:
async with semaphore:
result = await asyncio.wait_for(
submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
),
timeout=QUESTION_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
async with progress_lock:
stuck_count += 1
logger.warning(
"Question %s timed out after %ss (attempt %s/%s, "
"total stuck: %s) — retrying in %ss",
question_record.question_id,
QUESTION_TIMEOUT_SECONDS,
attempt,
MAX_QUESTION_ATTEMPTS,
stuck_count,
QUESTION_RETRY_PAUSE_SECONDS,
)
print_progress()
last_error = TimeoutError(
f"Timed out after {QUESTION_TIMEOUT_SECONDS}s "
f"on attempt {attempt}/{MAX_QUESTION_ATTEMPTS}"
try:
async with semaphore:
result = await submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
)
await asyncio.sleep(QUESTION_RETRY_PAUSE_SECONDS)
continue
except Exception as exc:
duration = time.monotonic() - q_start
async with progress_lock:
completed += 1
question_durations.append(duration)
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
remaining_count,
)
print_progress()
return
duration = time.monotonic() - q_start
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
except Exception as exc:
async with progress_lock:
completed += 1
successful += 1
question_durations.append(duration)
print_progress()
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
total,
)
return
# All attempts exhausted due to timeouts
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
async with progress_lock:
completed += 1
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(last_error),
)
)
logger.error(
"Question %s failed after %s timeout attempts (%s/%s)",
question_record.question_id,
MAX_QUESTION_ATTEMPTS,
completed,
remaining_count,
)
print_progress()
successful += 1
logger.info("Processed %s/%s questions", completed, total)
await asyncio.gather(
*(process_question(question_record) for question_record in questions)
)
# Final newline after progress bar
sys.stderr.write("\n")
sys.stderr.flush()
total_elapsed = time.monotonic() - run_start_time
avg_time = (
sum(question_durations) / len(question_durations)
if question_durations
else 0.0
)
stuck_suffix = f", {stuck_count} stuck timeouts" if stuck_count else ""
resume_suffix = (
f"{skipped} previously completed, "
f"{skipped + successful}/{overall_total} overall"
if skipped
else ""
)
logger.info(
"Done: %s/%s successful in %.1fs (avg %.1fs/question%s)%s",
successful,
remaining_count,
total_elapsed,
avg_time,
stuck_suffix,
resume_suffix,
)
if failed_questions:
logger.warning(
"%s questions failed:",
"Completed with %s failed questions and %s successful questions.",
len(failed_questions),
successful,
)
for failed_question in failed_questions:
logger.warning(
@@ -599,30 +453,7 @@ def main() -> None:
raise ValueError("`--max-questions` must be at least 1 when provided.")
questions = questions[: args.max_questions]
completed_ids = load_completed_question_ids(args.output_file)
logger.info(
"Found %s already-answered question IDs in %s",
len(completed_ids),
args.output_file,
)
total_before_filter = len(questions)
questions = [q for q in questions if q.question_id not in completed_ids]
skipped = total_before_filter - len(questions)
if skipped:
logger.info(
"Resuming: %s/%s already answered, %s remaining",
skipped,
total_before_filter,
len(questions),
)
else:
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
if not questions:
logger.info("All questions already answered. Nothing to do.")
return
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
logger.info("Writing answers to %s", args.output_file)
asyncio.run(
@@ -632,7 +463,6 @@ def main() -> None:
api_base=api_base,
api_key=args.api_key,
parallelism=args.parallelism,
skipped=skipped,
)
)

View File

@@ -27,11 +27,13 @@ def create_placement(
turn_index: int,
tab_index: int = 0,
sub_turn_index: int | None = None,
model_index: int | None = 0,
) -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
model_index=model_index,
)

View File

@@ -13,6 +13,7 @@ This test:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import patch
from uuid import uuid4
@@ -20,7 +21,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.chat.emitter import Emitter
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport
@@ -137,7 +138,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -200,7 +201,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -275,7 +276,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -350,7 +351,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -458,7 +459,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),
@@ -541,7 +542,7 @@ class TestMCPPassThroughOAuth:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=SearchToolConfig(),

View File

@@ -8,6 +8,7 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
All external HTTP calls are mocked, but Postgres and Redis are running.
"""
import queue
from typing import Any
from unittest.mock import Mock
from unittest.mock import patch
@@ -16,7 +17,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.chat.emitter import Emitter
from onyx.db.models import OAuthAccount
from onyx.db.models import OAuthConfig
from onyx.db.models import Persona
@@ -174,7 +175,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
search_tool_config=search_tool_config,
@@ -232,7 +233,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -284,7 +285,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -345,7 +346,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -416,7 +417,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -483,7 +484,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)
@@ -536,7 +537,7 @@ class TestOAuthToolIntegrationPriority:
tool_dict = construct_tools(
persona=persona,
db_session=db_session,
emitter=get_default_emitter(),
emitter=Emitter(merged_queue=queue.Queue()),
user=user,
llm=llm,
)

View File

@@ -0,0 +1,173 @@
"""Unit tests for the Emitter class.
All tests use the streaming mode (merged_queue required). Emitter has a single
code path — no standalone bus.
"""
import queue
from onyx.chat.emitter import Emitter
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _placement(
    turn_index: int = 0,
    tab_index: int = 0,
    sub_turn_index: int | None = None,
) -> Placement:
    """Build a Placement from the three coordinates, leaving model_index unset."""
    coordinates = {
        "turn_index": turn_index,
        "tab_index": tab_index,
        "sub_turn_index": sub_turn_index,
    }
    return Placement(**coordinates)
def _packet(
    turn_index: int = 0,
    tab_index: int = 0,
    sub_turn_index: int | None = None,
) -> Packet:
    """Return a minimal valid packet whose payload is an OverallStop."""
    where = _placement(turn_index, tab_index, sub_turn_index)
    payload = OverallStop(stop_reason="test")
    return Packet(placement=where, obj=payload)
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
    """Create a fresh queue and an Emitter that writes to it; return both."""
    merged: queue.Queue = queue.Queue()
    emitter = Emitter(merged_queue=merged, model_idx=model_idx)
    return emitter, merged
# ---------------------------------------------------------------------------
# Queue routing
# ---------------------------------------------------------------------------
class TestEmitterQueueRouting:
    """Emitted packets arrive on the merged queue, as 2-tuples, in FIFO order."""

    def test_emit_lands_on_merged_queue(self) -> None:
        sender, merged = _make_emitter()
        sender.emit(_packet())
        queue_is_empty = merged.empty()
        assert not queue_is_empty

    def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
        sender, merged = _make_emitter(model_idx=1)
        sender.emit(_packet())
        entry = merged.get_nowait()
        assert isinstance(entry, tuple)
        assert len(entry) == 2

    def test_multiple_packets_delivered_fifo(self) -> None:
        sender, merged = _make_emitter()
        first = _packet(turn_index=0)
        second = _packet(turn_index=1)
        for outgoing in (first, second):
            sender.emit(outgoing)
        _, got_first = merged.get_nowait()
        _, got_second = merged.get_nowait()
        assert got_first.placement.turn_index == 0
        assert got_second.placement.turn_index == 1
# ---------------------------------------------------------------------------
# model_index tagging
# ---------------------------------------------------------------------------
class TestEmitterModelIndexTagging:
    """The emitter's model_idx is stamped onto each packet's placement."""

    @staticmethod
    def _emit_and_read_model_index(model_idx: int) -> int | None:
        """Emit one packet through an emitter for ``model_idx``; return its tag."""
        sender, merged = _make_emitter(model_idx=model_idx)
        sender.emit(_packet())
        _key, tagged = merged.get_nowait()
        return tagged.placement.model_index

    def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
        """N=1: default model_idx=0, so packet gets model_index=0."""
        assert self._emit_and_read_model_index(0) == 0

    def test_model_idx_one_tags_packet(self) -> None:
        assert self._emit_and_read_model_index(1) == 1

    def test_model_idx_two_tags_packet(self) -> None:
        """Boundary: third model in a 3-model run."""
        assert self._emit_and_read_model_index(2) == 2
# ---------------------------------------------------------------------------
# Queue key
# ---------------------------------------------------------------------------
class TestEmitterQueueKey:
    """The first element of each queued tuple is the emitter's model_idx."""

    @staticmethod
    def _emit_and_read_key(model_idx: int) -> object:
        """Emit one packet through an emitter for ``model_idx``; return its queue key."""
        sender, merged = _make_emitter(model_idx=model_idx)
        sender.emit(_packet())
        queue_key, _ = merged.get_nowait()
        return queue_key

    def test_key_equals_model_idx(self) -> None:
        """Drain loop uses the key to route packets; it must match model_idx."""
        assert self._emit_and_read_key(2) == 2

    def test_n1_key_is_zero(self) -> None:
        assert self._emit_and_read_key(0) == 0
# ---------------------------------------------------------------------------
# Placement field preservation
# ---------------------------------------------------------------------------
class TestEmitterPlacementPreservation:
    """Tagging must leave the other placement fields and the payload intact."""

    @staticmethod
    def _round_trip(outgoing: Packet) -> Packet:
        """Emit ``outgoing`` through a default emitter and return what was queued."""
        sender, merged = _make_emitter()
        sender.emit(outgoing)
        _, received = merged.get_nowait()
        return received

    def test_turn_index_is_preserved(self) -> None:
        received = self._round_trip(_packet(turn_index=5))
        assert received.placement.turn_index == 5

    def test_tab_index_is_preserved(self) -> None:
        received = self._round_trip(_packet(tab_index=3))
        assert received.placement.tab_index == 3

    def test_sub_turn_index_is_preserved(self) -> None:
        received = self._round_trip(_packet(sub_turn_index=2))
        assert received.placement.sub_turn_index == 2

    def test_sub_turn_index_none_is_preserved(self) -> None:
        received = self._round_trip(_packet(sub_turn_index=None))
        assert received.placement.sub_turn_index is None

    def test_packet_obj_is_not_modified(self) -> None:
        """The payload object must survive tagging untouched."""
        sentinel_payload = OverallStop(stop_reason="sentinel")
        received = self._round_trip(
            Packet(placement=_placement(), obj=sentinel_payload)
        )
        assert received.obj is sentinel_payload

    def test_different_obj_types_are_handled(self) -> None:
        """Any valid PacketObj type passes through correctly."""
        received = self._round_trip(
            Packet(placement=_placement(), obj=ReasoningStart())
        )
        assert isinstance(received.obj, ReasoningStart)

View File

@@ -0,0 +1,662 @@
"""Unit tests for multi-model streaming validation and DB helpers.
These are pure unit tests — no real database or LLM calls required.
The validation logic in handle_multi_model_stream fires before any external
calls, so we can trigger it with lightweight mocks.
"""
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4

import pytest

from onyx.chat.models import StreamingError
from onyx.configs.constants import MessageType
from onyx.db.chat import set_preferred_response
from onyx.llm.override_models import LLMOverride
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.utils.variable_functionality import global_version
@pytest.fixture(autouse=True)
def _restore_ee_version() -> Generator[None, None, None]:
    """Snapshot the EE flag before each test and put it back afterwards.

    Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
    (via the celery client import chain). Left alone, the flag would stay True for
    the remainder of the session and break unrelated tests that mock Confluence
    or other connectors and assume EE is disabled.
    """
    saved_is_ee = global_version._is_ee
    yield
    # Teardown: undo whatever the test (or an import it triggered) did to the flag.
    global_version._is_ee = saved_is_ee
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_request(**kwargs: Any) -> SendMessageRequest:
    """Build a SendMessageRequest with a default message and a fresh session id.

    Any keyword argument overrides the corresponding default or adds a field.
    """
    fields: dict[str, Any] = {
        "message": "hello",
        "chat_session_id": uuid4(),
        **kwargs,
    }
    return SendMessageRequest(**fields)
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
    """Build an LLMOverride for the given provider/version (defaults: openai / gpt-4)."""
    override = LLMOverride(model_provider=provider, model_version=version)
    return override
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
    """Return the first item yielded by handle_multi_model_stream.

    The user and DB session are MagicMocks; the user is configured as a
    non-anonymous account with a fixed e-mail address.
    """
    # Imported lazily, as in the rest of this module, so the import side effects
    # happen inside the test rather than at collection time.
    from onyx.chat.process_message import handle_multi_model_stream

    fake_user = MagicMock(is_anonymous=False, email="test@example.com")
    fake_db = MagicMock()
    stream = handle_multi_model_stream(req, fake_user, fake_db, overrides)
    return next(stream)
# ---------------------------------------------------------------------------
# handle_multi_model_stream — validation
# ---------------------------------------------------------------------------
class TestRunMultiModelStreamValidation:
    """Validation guards of handle_multi_model_stream, observed via its first yield."""

    def test_single_override_yields_error(self) -> None:
        """Exactly 1 override is not multi-model — yields StreamingError."""
        first = _first_from_stream(_make_request(), [_make_override()])

        assert isinstance(first, StreamingError)
        assert "2-3" in first.error

    def test_four_overrides_yields_error(self) -> None:
        """4 overrides exceeds maximum — yields StreamingError."""
        too_many = [
            _make_override("openai", "gpt-4"),
            _make_override("anthropic", "claude-3"),
            _make_override("google", "gemini-pro"),
            _make_override("cohere", "command-r"),
        ]
        first = _first_from_stream(_make_request(), too_many)

        assert isinstance(first, StreamingError)
        assert "2-3" in first.error

    def test_zero_overrides_yields_error(self) -> None:
        """Empty override list yields StreamingError."""
        first = _first_from_stream(_make_request(), [])

        assert isinstance(first, StreamingError)
        assert "2-3" in first.error

    def test_deep_research_yields_error(self) -> None:
        """deep_research=True is incompatible with multi-model — yields StreamingError."""
        request = _make_request(deep_research=True)
        pair = [_make_override(), _make_override("anthropic", "claude-3")]
        first = _first_from_stream(request, pair)

        assert isinstance(first, StreamingError)
        assert "not supported" in first.error

    def test_exactly_two_overrides_is_minimum(self) -> None:
        """Boundary: 1 override yields error, 2 overrides passes validation."""
        req = _make_request()

        # Below the boundary: a single override must be rejected.
        result = _first_from_stream(req, [_make_override()])
        assert isinstance(
            result, StreamingError
        ), "1 override should yield StreamingError"

        # At the boundary: two overrides must get past the count check. Anything
        # that blows up afterwards (e.g. the mocked session is unusable) is fine —
        # it only proves validation was passed.
        try:
            result2 = _first_from_stream(
                req, [_make_override(), _make_override("anthropic", "claude-3")]
            )
            count_check_tripped = (
                isinstance(result2, StreamingError) and "2-3" in result2.error
            )
            if count_check_tripped:
                # pytest.fail raises an outcome exception derived from
                # BaseException, so the handler below does not swallow it.
                pytest.fail(
                    f"2 overrides should pass validation, got StreamingError: {result2.error}"
                )
        except Exception:
            pass  # Any non-validation error means validation passed
# ---------------------------------------------------------------------------
# set_preferred_response — validation (mocked db)
# ---------------------------------------------------------------------------
class TestSetPreferredResponseValidation:
    """Guards and happy path of set_preferred_response against a mocked session.

    ``db.get`` is the only session method configured; where two lookups are
    needed, ``side_effect`` supplies the user message first and the assistant
    message second.
    """

    def test_user_message_not_found(self) -> None:
        session = MagicMock()
        session.get.return_value = None

        with pytest.raises(ValueError, match="not found"):
            set_preferred_response(
                session, user_message_id=999, preferred_assistant_message_id=1
            )

    def test_wrong_message_type(self) -> None:
        """Cannot set preferred response on a non-USER message."""
        session = MagicMock()
        # Deliberately the wrong type for the "user" message lookup.
        session.get.return_value = MagicMock(message_type=MessageType.ASSISTANT)

        with pytest.raises(ValueError, match="not a user message"):
            set_preferred_response(
                session, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_assistant_message_not_found(self) -> None:
        session = MagicMock()
        user_message = MagicMock(message_type=MessageType.USER)
        # First lookup finds the user message; the assistant lookup finds nothing.
        session.get.side_effect = [user_message, None]

        with pytest.raises(ValueError, match="not found"):
            set_preferred_response(
                session, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_assistant_not_child_of_user(self) -> None:
        session = MagicMock()
        user_message = MagicMock(message_type=MessageType.USER)
        # Parent id differs from the user_message_id passed below.
        orphan_reply = MagicMock(parent_message_id=999)
        session.get.side_effect = [user_message, orphan_reply]

        with pytest.raises(ValueError, match="not a child"):
            set_preferred_response(
                session, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_valid_call_sets_preferred_response_id(self) -> None:
        session = MagicMock()
        user_message = MagicMock(message_type=MessageType.USER)
        # Parent id matches the user_message_id passed below.
        child_reply = MagicMock(parent_message_id=1)
        session.get.side_effect = [user_message, child_reply]

        set_preferred_response(
            session, user_message_id=1, preferred_assistant_message_id=2
        )

        assert user_message.preferred_response_id == 2
        assert user_message.latest_child_message_id == 2
# ---------------------------------------------------------------------------
# LLMOverride — display_name field
# ---------------------------------------------------------------------------
class TestLLMOverrideDisplayName:
    """Covers the optional ``display_name`` field on LLMOverride."""

    def test_display_name_defaults_none(self) -> None:
        without_name = LLMOverride(model_provider="openai", model_version="gpt-4")
        assert without_name.display_name is None

    def test_display_name_set(self) -> None:
        named = LLMOverride(
            model_provider="openai",
            model_version="gpt-4",
            display_name="GPT-4 Turbo",
        )
        assert named.display_name == "GPT-4 Turbo"

    def test_display_name_serializes(self) -> None:
        named = LLMOverride(
            model_provider="anthropic",
            model_version="claude-opus-4-6",
            display_name="Claude Opus",
        )
        dumped = named.model_dump()
        assert dumped["display_name"] == "Claude Opus"
# ---------------------------------------------------------------------------
# _run_models — drain loop behaviour
# ---------------------------------------------------------------------------
def _make_setup(n_models: int = 1) -> MagicMock:
    """Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model.

    Args:
        n_models: how many mock LLMs, display names ("model-0", "model-1", …)
            and reserved messages to attach.

    Returns:
        A MagicMock with the per-model lists populated and the scalar fields
        below pinned to concrete values.
    """
    setup = MagicMock()

    # Per-model collections, all of length n_models.
    model_range = range(n_models)
    setup.llms = [MagicMock() for _ in model_range]
    setup.model_display_names = [f"model-{i}" for i in model_range]
    setup.check_is_connected = MagicMock(return_value=True)
    setup.reserved_messages = [MagicMock() for _ in model_range]
    setup.reserved_token_count = 100

    # Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
    # constructors inside _run_model — must be typed correctly for Pydantic.
    request = setup.new_msg_req
    request.deep_research = False
    request.internal_search_filters = None
    request.allowed_tool_ids = None
    request.include_citations = True

    search_params = setup.search_params
    search_params.project_id_filter = None
    search_params.persona_id_filter = None

    setup.bypass_acl = False
    setup.slack_context = None

    files = setup.available_files
    files.user_file_ids = []
    files.chat_file_ids = []

    setup.forced_tool_id = None
    setup.simple_chat_history = []
    setup.chat_session.id = uuid4()
    setup.user_message.id = None
    setup.custom_tool_additional_headers = None
    setup.mcp_headers = None
    return setup
# Un-started patchers for the onyx.chat.process_message names that the
# _run_models tests stub out: both LLM loops, tool construction, the DB session
# factory, the completion handler and the token counter (stubbed to return 0).
# NOTE(review): no reference to this list appears in the visible part of this
# file — confirm it is actually used before keeping it.
_RUN_MODELS_PATCHES = [
patch("onyx.chat.process_message.run_llm_loop"),
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
patch("onyx.chat.process_message.construct_tools", return_value={}),
patch("onyx.chat.process_message.get_session_with_current_tenant"),
patch("onyx.chat.process_message.llm_loop_completion_handle"),
patch("onyx.chat.process_message.get_llm_token_counter", return_value=lambda _: 0),
]
def _run_models_collect(setup: MagicMock) -> list:
    """Drive _run_models to completion and return all yielded items.

    The two positional arguments after ``setup`` are throwaway MagicMocks.
    """
    from onyx.chat.process_message import _run_models

    stream = _run_models(setup, MagicMock(), MagicMock())
    return list(stream)
class TestRunModels:
    """Tests for the _run_models worker-thread drain loop.

    All external dependencies (LLM, DB, tools) are patched out. Worker threads
    still run but return immediately since run_llm_loop is mocked.

    The six patches every test needs live in ``_patched_deps`` so each test only
    states what is specific to it: the ``run_llm_loop`` side effect, the setup,
    and the assertions.
    """

    # ------------------------------------------------------------------
    # Shared helpers (not collected by pytest: names do not start with "test")
    # ------------------------------------------------------------------

    @staticmethod
    @contextmanager
    def _patched_deps(
        llm_side_effect: Any = None,
    ) -> Generator[tuple[MagicMock, MagicMock], None, None]:
        """Patch every process_message dependency that _run_models reaches.

        Args:
            llm_side_effect: installed as ``side_effect`` of the patched
                ``run_llm_loop``. ``None`` (the default) leaves it a plain mock
                that returns immediately.

        Yields:
            ``(mock_run_llm_loop, mock_completion_handle)`` so tests can inspect
            how the LLM loop was called and how often completion was persisted.
        """
        with (
            patch(
                "onyx.chat.process_message.run_llm_loop",
                side_effect=llm_side_effect,
            ) as mock_llm,
            patch("onyx.chat.process_message.run_deep_research_llm_loop"),
            patch("onyx.chat.process_message.construct_tools", return_value={}),
            patch("onyx.chat.process_message.get_session_with_current_tenant"),
            patch(
                "onyx.chat.process_message.llm_loop_completion_handle"
            ) as mock_handle,
            patch(
                "onyx.chat.process_message.get_llm_token_counter",
                return_value=lambda _: 0,
            ),
        ):
            yield mock_llm, mock_handle

    @staticmethod
    def _packets_with(items: list, obj_type: type) -> list[Packet]:
        """Return the Packet items whose payload is an instance of ``obj_type``."""
        return [
            item
            for item in items
            if isinstance(item, Packet) and isinstance(item.obj, obj_type)
        ]

    @staticmethod
    def _emit_reasoning_start(**kwargs: Any) -> None:
        """run_llm_loop stand-in: emit a single ReasoningStart packet and return."""
        kwargs["emitter"].emit(
            Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
        )

    @staticmethod
    def _slow_llm(**_kwargs: Any) -> None:
        """run_llm_loop stand-in that just blocks for a while."""
        time.sleep(0.3)  # Outlasts the 50 ms queue-poll interval

    # ------------------------------------------------------------------
    # Packet pass-through and tagging
    # ------------------------------------------------------------------

    def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
        """OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""

        def emit_stop(**kwargs: Any) -> None:
            kwargs["emitter"].emit(
                Packet(
                    placement=Placement(turn_index=0),
                    obj=OverallStop(stop_reason="complete"),
                )
            )

        with self._patched_deps(emit_stop):
            packets = _run_models_collect(_make_setup(n_models=1))

        stops = self._packets_with(packets, OverallStop)
        assert len(stops) == 1
        stop_obj = stops[0].obj
        assert isinstance(stop_obj, OverallStop)
        assert stop_obj.stop_reason == "complete"

    def test_n1_emitted_packet_has_model_index_zero(self) -> None:
        """Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
        with self._patched_deps(self._emit_reasoning_start):
            packets = _run_models_collect(_make_setup(n_models=1))

        reasoning = self._packets_with(packets, ReasoningStart)
        assert len(reasoning) == 1
        assert reasoning[0].placement.model_index == 0

    def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
        """Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
        # _model_idx is set by _run_model based on position in setup.llms, so the
        # same side effect yields differently tagged packets per worker.
        with self._patched_deps(self._emit_reasoning_start):
            packets = _run_models_collect(_make_setup(n_models=2))

        reasoning = self._packets_with(packets, ReasoningStart)
        assert len(reasoning) == 2
        indices = {p.placement.model_index for p in reasoning}
        assert indices == {0, 1}

    # ------------------------------------------------------------------
    # Error isolation
    # ------------------------------------------------------------------

    def test_model_error_yields_streaming_error(self) -> None:
        """An exception inside a worker thread is surfaced as a StreamingError."""

        def always_fail(**_kwargs: Any) -> None:
            raise RuntimeError("intentional test failure")

        with self._patched_deps(always_fail):
            packets = _run_models_collect(_make_setup(n_models=1))

        errors = [p for p in packets if isinstance(p, StreamingError)]
        assert len(errors) == 1
        assert errors[0].error_code == "MODEL_ERROR"
        assert "intentional test failure" in errors[0].error

    def test_one_model_error_does_not_stop_other_models(self) -> None:
        """A failing model yields StreamingError; the surviving model's packets still arrive."""

        def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
            emitter = kwargs["emitter"]
            # _model_idx is always int (0 for N=1, 0/1/2… for N>1)
            if emitter._model_idx == 0:
                raise RuntimeError("model 0 failed")
            emitter.emit(
                Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
            )

        with self._patched_deps(fail_model_0_succeed_model_1):
            packets = _run_models_collect(_make_setup(n_models=2))

        errors = [p for p in packets if isinstance(p, StreamingError)]
        assert len(errors) == 1

        reasoning = self._packets_with(packets, ReasoningStart)
        assert len(reasoning) == 1
        assert reasoning[0].placement.model_index == 1

    # ------------------------------------------------------------------
    # Cancellation / disconnect
    # ------------------------------------------------------------------

    def test_cancellation_yields_user_cancelled_stop(self) -> None:
        """If check_is_connected returns False, drain loop emits user_cancelled."""
        setup = _make_setup(n_models=1)
        setup.check_is_connected = MagicMock(return_value=False)

        with self._patched_deps(self._slow_llm):
            packets = _run_models_collect(setup)

        stops = self._packets_with(packets, OverallStop)
        assert any(
            isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
            for s in stops
        )

    def test_completion_handle_called_on_disconnect(self) -> None:
        """llm_loop_completion_handle must still be called even when user disconnects.

        Regression test for the disconnect-cleanup bug: the old
        run_chat_loop_with_state_containers always called completion_callback in
        its finally block (even on disconnect) so the DB message was updated from
        the TERMINATED placeholder to a partial answer. The new _run_models must
        replicate this — otherwise the integration test
        test_send_message_disconnect_and_cleanup fails because the message stays
        as "Response was terminated prior to completion, try regenerating."
        """
        setup = _make_setup(n_models=2)
        setup.check_is_connected = MagicMock(return_value=False)

        with self._patched_deps(self._slow_llm) as (_, mock_handle):
            _run_models_collect(setup)

        # Must be called once per model, not zero times
        assert mock_handle.call_count == 2

    # ------------------------------------------------------------------
    # Completion persistence
    # ------------------------------------------------------------------

    def test_completion_handle_called_for_each_successful_model(self) -> None:
        """llm_loop_completion_handle must be called once per model that succeeded."""
        setup = _make_setup(n_models=2)

        with self._patched_deps() as (_, mock_handle):
            _run_models_collect(setup)

        assert mock_handle.call_count == 2

    def test_completion_handle_not_called_for_failed_model(self) -> None:
        """llm_loop_completion_handle must be skipped for a model that raised."""

        def always_fail(**_kwargs: Any) -> None:
            raise RuntimeError("fail")

        with self._patched_deps(always_fail) as (_, mock_handle):
            _run_models_collect(_make_setup(n_models=1))

        mock_handle.assert_not_called()

    def test_http_disconnect_completion_via_generator_exit(self) -> None:
        """GeneratorExit from HTTP disconnect triggers wait+completion in finally.

        When the HTTP client closes the connection, Starlette throws GeneratorExit
        into the stream generator, which propagates into _run_models. The finally
        block must call executor.shutdown(wait=True) to wait for LLM threads to
        finish, then persist their results via llm_loop_completion_handle.

        This is the primary regression for test_send_message_disconnect_and_cleanup:
        the integration test disconnects mid-stream and expects the DB message to be
        updated from the TERMINATED placeholder to the real response.
        """
        import threading

        thread_completed = threading.Event()

        def emit_then_complete(**kwargs: Any) -> None:
            """Emit one packet (to give generator a yield point), then finish."""
            emitter = kwargs["emitter"]
            emitter.emit(
                Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
            )
            # Small sleep so executor.shutdown(wait=True) in finally actually waits.
            time.sleep(0.05)
            thread_completed.set()

        setup = _make_setup(n_models=1)
        # is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
        setup.check_is_connected = MagicMock(return_value=True)

        with self._patched_deps(emit_then_complete) as (_, mock_handle):
            from onyx.chat.process_message import _run_models

            # cast to Generator so .close() is available; _run_models returns
            # AnswerStream (= Iterator) but the actual object is always a generator.
            gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))

            # Advance to the first yielded packet — generator suspends at `yield item`.
            first = next(gen)
            assert isinstance(first, Packet)

            # Simulate Starlette closing the stream on HTTP client disconnect.
            # GeneratorExit is thrown at the `yield item` suspension point.
            gen.close()

        # Finally block must have waited for the thread and saved completion.
        assert (
            thread_completed.is_set()
        ), "LLM thread must complete before gen.close() returns"
        assert (
            mock_handle.call_count == 1
        ), "completion handle must be called for the successful model"

    # ------------------------------------------------------------------
    # State container wiring
    # ------------------------------------------------------------------

    def test_external_state_container_used_for_model_zero(self) -> None:
        """When provided, external_state_container is used as state_containers[0]."""
        from onyx.chat.chat_state import ChatStateContainer
        from onyx.chat.process_message import _run_models

        external = ChatStateContainer()
        setup = _make_setup(n_models=1)

        with self._patched_deps() as (mock_llm, _):
            list(
                _run_models(
                    setup, MagicMock(), MagicMock(), external_state_container=external
                )
            )

        # The state_container kwarg passed to run_llm_loop must be the external one
        call_kwargs = mock_llm.call_args.kwargs
        assert call_kwargs["state_container"] is external

View File

@@ -1,23 +1,15 @@
"""Tests for Canvas connector — client, credentials, conversion."""
"""Tests for Canvas connector — client (PR1)."""
from datetime import datetime
from datetime import timezone
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.canvas.client import CanvasApiClient
from onyx.connectors.canvas.connector import CanvasConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.error_handling.exceptions import OnyxError
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@@ -26,77 +18,6 @@ FAKE_BASE_URL = "https://myschool.instructure.com"
FAKE_TOKEN = "fake-canvas-token"
def _mock_course(
course_id: int = 1,
name: str = "Intro to CS",
course_code: str = "CS101",
) -> dict[str, Any]:
return {
"id": course_id,
"name": name,
"course_code": course_code,
"created_at": "2025-01-01T00:00:00Z",
"workflow_state": "available",
}
def _build_connector(base_url: str = FAKE_BASE_URL) -> CanvasConnector:
    """Build a connector with mocked credential validation.

    ``rl_requests`` in the client module is patched while the connector is
    constructed and its credentials are loaded; its ``get`` returns a mocked
    response carrying a single course.
    """
    credentials = {"canvas_access_token": FAKE_TOKEN}
    with patch("onyx.connectors.canvas.client.rl_requests") as fake_requests:
        fake_requests.get.return_value = _mock_response(json_data=[_mock_course()])
        connector = CanvasConnector(canvas_base_url=base_url)
        connector.load_credentials(credentials)
    return connector
def _mock_page(
page_id: int = 10,
title: str = "Syllabus",
updated_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
return {
"page_id": page_id,
"url": "syllabus",
"title": title,
"body": "<p>Welcome to the course</p>",
"created_at": "2025-01-15T00:00:00Z",
"updated_at": updated_at,
}
def _mock_assignment(
    assignment_id: int = 20,
    name: str = "Homework 1",
    course_id: int = 1,
    updated_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
    """Return an assignment payload dict for tests.

    ``html_url`` is derived from FAKE_BASE_URL, the course id and the
    assignment id; description, ``created_at`` and ``due_at`` are fixed.
    """
    assignment_url = f"{FAKE_BASE_URL}/courses/{course_id}/assignments/{assignment_id}"
    assignment: dict[str, Any] = dict(
        id=assignment_id,
        name=name,
        description="<p>Solve these problems</p>",
        html_url=assignment_url,
        course_id=course_id,
        created_at="2025-01-20T00:00:00Z",
        updated_at=updated_at,
        due_at="2025-02-01T23:59:00Z",
    )
    return assignment
def _mock_announcement(
    announcement_id: int = 30,
    title: str = "Class Cancelled",
    course_id: int = 1,
    posted_at: str = "2025-06-01T12:00:00Z",
) -> dict[str, Any]:
    """Return an announcement payload dict for tests.

    ``html_url`` points at the discussion topic under FAKE_BASE_URL for the
    given course and announcement ids; the message body is fixed.
    """
    topic_url = (
        f"{FAKE_BASE_URL}/courses/{course_id}/discussion_topics/{announcement_id}"
    )
    announcement: dict[str, Any] = dict(
        id=announcement_id,
        title=title,
        message="<p>No class today</p>",
        html_url=topic_url,
        posted_at=posted_at,
    )
    return announcement
def _mock_response(
status_code: int = 200,
json_data: Any = None,
@@ -404,57 +325,6 @@ class TestGet:
assert result == expected
# ---------------------------------------------------------------------------
# CanvasApiClient.paginate tests
# ---------------------------------------------------------------------------
class TestPaginate:
    """CanvasApiClient.paginate against a patched ``rl_requests``."""

    @staticmethod
    def _new_client() -> CanvasApiClient:
        """Client built from the fake token and base URL shared by this module."""
        return CanvasApiClient(
            bearer_token=FAKE_TOKEN,
            canvas_base_url=FAKE_BASE_URL,
        )

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_single_page(self, mock_requests: MagicMock) -> None:
        """A response without a next link produces exactly one page."""
        items = [{"id": 1}, {"id": 2}]
        mock_requests.get.return_value = _mock_response(json_data=items)
        client = self._new_client()

        pages = list(client.paginate("courses"))

        assert len(pages) == 1
        assert pages[0] == [{"id": 1}, {"id": 2}]

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_two_pages(self, mock_requests: MagicMock) -> None:
        """A rel="next" Link header makes paginate fetch and yield a second page."""
        next_link = f'<{FAKE_BASE_URL}/api/v1/courses?page=2>; rel="next"'
        first_response = _mock_response(json_data=[{"id": 1}], link_header=next_link)
        second_response = _mock_response(json_data=[{"id": 2}])
        mock_requests.get.side_effect = [first_response, second_response]
        client = self._new_client()

        pages = list(client.paginate("courses"))

        assert len(pages) == 2
        assert pages[0] == [{"id": 1}]
        assert pages[1] == [{"id": 2}]

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_empty_response(self, mock_requests: MagicMock) -> None:
        """An empty JSON list yields no pages at all."""
        mock_requests.get.return_value = _mock_response(json_data=[])
        client = self._new_client()

        pages = list(client.paginate("courses"))

        assert pages == []
# ---------------------------------------------------------------------------
# CanvasApiClient._parse_next_link tests
# ---------------------------------------------------------------------------
@@ -509,368 +379,3 @@ class TestParseNextLink:
with pytest.raises(OnyxError, match="must use https"):
self.client._parse_next_link(header)
# ---------------------------------------------------------------------------
# CanvasConnector — credential loading
# ---------------------------------------------------------------------------
class TestLoadCredentials:
    """CanvasConnector.load_credentials and the canvas_client accessor."""

    def _assert_load_credentials_raises(
        self,
        status_code: int,
        expected_error: type[Exception],
        mock_requests: MagicMock,
    ) -> None:
        """Helper: assert load_credentials raises expected_error for a given status."""
        mock_requests.get.return_value = _mock_response(status_code, {})
        connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
        credentials = {"canvas_access_token": FAKE_TOKEN}

        with pytest.raises(expected_error):
            connector.load_credentials(credentials)

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_load_credentials_success(self, mock_requests: MagicMock) -> None:
        """A successful load returns None and populates the private client."""
        mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
        connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)

        outcome = connector.load_credentials({"canvas_access_token": FAKE_TOKEN})

        assert outcome is None
        assert connector._canvas_client is not None

    def test_canvas_client_raises_without_credentials(self) -> None:
        """Reading canvas_client before loading credentials raises."""
        connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)

        with pytest.raises(ConnectorMissingCredentialError):
            _ = connector.canvas_client

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_load_credentials_invalid_token(self, mock_requests: MagicMock) -> None:
        """HTTP 401 maps to CredentialExpiredError."""
        self._assert_load_credentials_raises(401, CredentialExpiredError, mock_requests)

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_load_credentials_insufficient_permissions(
        self, mock_requests: MagicMock
    ) -> None:
        """HTTP 403 maps to InsufficientPermissionsError."""
        self._assert_load_credentials_raises(
            403, InsufficientPermissionsError, mock_requests
        )
# ---------------------------------------------------------------------------
# CanvasConnector — URL normalization
# ---------------------------------------------------------------------------
class TestConnectorUrlNormalization:
    """The connector's canvas_base_url for several spellings of the same base URL."""

    def test_strips_api_v1_suffix(self) -> None:
        with_api_suffix = f"{FAKE_BASE_URL}/api/v1"
        connector = _build_connector(base_url=with_api_suffix)

        assert connector.canvas_base_url == FAKE_BASE_URL

    def test_strips_trailing_slash(self) -> None:
        with_trailing_slash = f"{FAKE_BASE_URL}/"
        connector = _build_connector(base_url=with_trailing_slash)

        assert connector.canvas_base_url == FAKE_BASE_URL

    def test_no_change_for_clean_url(self) -> None:
        connector = _build_connector(base_url=FAKE_BASE_URL)

        assert connector.canvas_base_url == FAKE_BASE_URL
# ---------------------------------------------------------------------------
# CanvasConnector — document conversion
# ---------------------------------------------------------------------------
class TestDocumentConversion:
def setup_method(self) -> None:
self.connector = _build_connector()
def test_convert_page_to_document(self) -> None:
from onyx.connectors.canvas.connector import CanvasPage
page = CanvasPage(
page_id=10,
url="syllabus",
title="Syllabus",
body="<p>Welcome</p>",
created_at="2025-01-15T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
course_id=1,
)
doc = self.connector._convert_page_to_document(page)
expected_id = "canvas-page-1-10"
expected_metadata = {"course_id": "1", "type": "page"}
expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
assert doc.id == expected_id
assert doc.source == DocumentSource.CANVAS
assert doc.semantic_identifier == "Syllabus"
assert doc.metadata == expected_metadata
assert doc.sections[0].link is not None
assert f"{FAKE_BASE_URL}/courses/1/pages/syllabus" in doc.sections[0].link
assert doc.doc_updated_at == expected_updated_at
def test_convert_page_without_body(self) -> None:
from onyx.connectors.canvas.connector import CanvasPage
page = CanvasPage(
page_id=11,
url="empty-page",
title="Empty Page",
body=None,
created_at="2025-01-15T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
course_id=1,
)
doc = self.connector._convert_page_to_document(page)
section_text = doc.sections[0].text
assert section_text is not None
assert "Empty Page" in section_text
assert "<p>" not in section_text
def test_convert_assignment_to_document(self) -> None:
from onyx.connectors.canvas.connector import CanvasAssignment
assignment = CanvasAssignment(
id=20,
name="Homework 1",
description="<p>Solve these</p>",
html_url=f"{FAKE_BASE_URL}/courses/1/assignments/20",
course_id=1,
created_at="2025-01-20T00:00:00Z",
updated_at="2025-06-01T12:00:00Z",
due_at="2025-02-01T23:59:00Z",
)
doc = self.connector._convert_assignment_to_document(assignment)
expected_id = "canvas-assignment-1-20"
expected_due_text = "Due: February 01, 2025 23:59 UTC"
assert doc.id == expected_id
assert doc.source == DocumentSource.CANVAS
assert doc.semantic_identifier == "Homework 1"
assert doc.sections[0].text is not None
assert expected_due_text in doc.sections[0].text
def test_convert_assignment_without_description(self) -> None:
    """With no description and no due date, the section text still contains
    the assignment name and has no ``Due:`` line."""
    from onyx.connectors.canvas.connector import CanvasAssignment

    bare_quiz = CanvasAssignment(
        id=21,
        name="Quiz 1",
        description=None,
        html_url=f"{FAKE_BASE_URL}/courses/1/assignments/21",
        course_id=1,
        created_at="2025-01-20T00:00:00Z",
        updated_at="2025-06-01T12:00:00Z",
        due_at=None,
    )

    converted = self.connector._convert_assignment_to_document(bare_quiz)
    first_section_text = converted.sections[0].text

    assert first_section_text is not None
    assert "Quiz 1" in first_section_text
    assert "Due:" not in first_section_text
def test_convert_announcement_to_document(self) -> None:
    """An announcement with ``posted_at`` set converts to a document whose
    ``doc_updated_at`` equals that timestamp in UTC."""
    from onyx.connectors.canvas.connector import CanvasAnnouncement

    notice = CanvasAnnouncement(
        id=30,
        title="Class Cancelled",
        message="<p>No class today</p>",
        html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/30",
        posted_at="2025-06-01T12:00:00Z",
        course_id=1,
    )

    converted = self.connector._convert_announcement_to_document(notice)

    assert converted.id == "canvas-announcement-1-30"
    assert converted.source == DocumentSource.CANVAS
    assert converted.semantic_identifier == "Class Cancelled"
    assert converted.doc_updated_at == datetime(
        2025, 6, 1, 12, 0, tzinfo=timezone.utc
    )
def test_convert_announcement_without_posted_at(self) -> None:
    """An announcement with ``posted_at=None`` produces a document whose
    ``doc_updated_at`` is ``None``."""
    from onyx.connectors.canvas.connector import CanvasAnnouncement

    undated_notice = CanvasAnnouncement(
        id=31,
        title="TBD Announcement",
        message=None,
        html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/31",
        posted_at=None,
        course_id=1,
    )

    converted = self.connector._convert_announcement_to_document(undated_notice)

    assert converted.doc_updated_at is None
# ---------------------------------------------------------------------------
# CanvasConnector — validate_connector_settings
# ---------------------------------------------------------------------------
class TestValidateConnectorSettings:
    """Checks how ``validate_connector_settings`` reacts to mocked HTTP responses."""

    def _assert_validate_raises(
        self,
        status_code: int,
        expected_error: type[Exception],
        mock_requests: MagicMock,
    ) -> None:
        """Assert that validation raises ``expected_error``.

        The mocked ``get`` is queued with two responses, consumed in order:
        a course listing first, then a response built from ``status_code``
        with an empty payload.
        """
        mock_requests.get.side_effect = [
            _mock_response(json_data=[_mock_course()]),
            _mock_response(status_code, {}),
        ]

        canvas = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
        canvas.load_credentials({"canvas_access_token": FAKE_TOKEN})

        with pytest.raises(expected_error):
            canvas.validate_connector_settings()

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_validate_success(self, mock_requests: MagicMock) -> None:
        """Validation completes when every mocked GET returns a course list."""
        course_listing = _mock_response(json_data=[_mock_course()])
        mock_requests.get.return_value = course_listing

        canvas = _build_connector()
        canvas.validate_connector_settings()  # should not raise

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_validate_expired_credential(self, mock_requests: MagicMock) -> None:
        """A 401 on the second request raises CredentialExpiredError."""
        self._assert_validate_raises(
            status_code=401,
            expected_error=CredentialExpiredError,
            mock_requests=mock_requests,
        )

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_validate_insufficient_permissions(self, mock_requests: MagicMock) -> None:
        """A 403 on the second request raises InsufficientPermissionsError."""
        self._assert_validate_raises(
            status_code=403,
            expected_error=InsufficientPermissionsError,
            mock_requests=mock_requests,
        )

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_validate_rate_limited(self, mock_requests: MagicMock) -> None:
        """A 429 on the second request raises ConnectorValidationError."""
        self._assert_validate_raises(
            status_code=429,
            expected_error=ConnectorValidationError,
            mock_requests=mock_requests,
        )

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_validate_unexpected_error(self, mock_requests: MagicMock) -> None:
        """A 500 on the second request raises UnexpectedValidationError."""
        self._assert_validate_raises(
            status_code=500,
            expected_error=UnexpectedValidationError,
            mock_requests=mock_requests,
        )
# ---------------------------------------------------------------------------
# _list_* pagination tests
# ---------------------------------------------------------------------------
class TestListCourses:
    """Single-response checks for the connector's ``_list_courses``."""

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_single_page(self, mock_requests: MagicMock) -> None:
        """Two courses in one response come back in the same order."""
        courses_payload = [_mock_course(1), _mock_course(2, "CS201", "Data Structures")]
        mock_requests.get.return_value = _mock_response(json_data=courses_payload)

        courses = _build_connector()._list_courses()

        assert len(courses) == 2
        assert (courses[0].id, courses[1].id) == (1, 2)

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_empty_response(self, mock_requests: MagicMock) -> None:
        """An empty JSON list yields an empty result."""
        mock_requests.get.return_value = _mock_response(json_data=[])

        courses = _build_connector()._list_courses()

        assert courses == []
class TestListPages:
    """Single-response checks for the connector's ``_list_pages``."""

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_single_page(self, mock_requests: MagicMock) -> None:
        """Two pages in one response come back in the same order."""
        pages_payload = [_mock_page(10), _mock_page(11, "Notes")]
        mock_requests.get.return_value = _mock_response(json_data=pages_payload)

        pages = _build_connector()._list_pages(course_id=1)

        assert len(pages) == 2
        assert (pages[0].page_id, pages[1].page_id) == (10, 11)

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_empty_response(self, mock_requests: MagicMock) -> None:
        """An empty JSON list yields an empty result."""
        mock_requests.get.return_value = _mock_response(json_data=[])

        pages = _build_connector()._list_pages(course_id=1)

        assert pages == []
class TestListAssignments:
    """Single-response checks for the connector's ``_list_assignments``."""

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_single_page(self, mock_requests: MagicMock) -> None:
        """Two assignments in one response come back in the same order."""
        assignments_payload = [_mock_assignment(20), _mock_assignment(21, "Quiz 1")]
        mock_requests.get.return_value = _mock_response(json_data=assignments_payload)

        assignments = _build_connector()._list_assignments(course_id=1)

        assert len(assignments) == 2
        assert (assignments[0].id, assignments[1].id) == (20, 21)

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_empty_response(self, mock_requests: MagicMock) -> None:
        """An empty JSON list yields an empty result."""
        mock_requests.get.return_value = _mock_response(json_data=[])

        assignments = _build_connector()._list_assignments(course_id=1)

        assert assignments == []
class TestListAnnouncements:
    """Single-response checks for the connector's ``_list_announcements``."""

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_single_page(self, mock_requests: MagicMock) -> None:
        """Two announcements in one response come back in the same order."""
        announcements_payload = [_mock_announcement(30), _mock_announcement(31, "Update")]
        mock_requests.get.return_value = _mock_response(json_data=announcements_payload)

        announcements = _build_connector()._list_announcements(course_id=1)

        assert len(announcements) == 2
        assert (announcements[0].id, announcements[1].id) == (30, 31)

    @patch("onyx.connectors.canvas.client.rl_requests")
    def test_empty_response(self, mock_requests: MagicMock) -> None:
        """An empty JSON list yields an empty result."""
        mock_requests.get.return_value = _mock_response(json_data=[])

        announcements = _build_connector()._list_announcements(course_id=1)

        assert announcements == []

View File

@@ -1,6 +1,6 @@
"""Tests for memory tool streaming packet emissions."""
from queue import Queue
import queue
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -18,9 +18,13 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
@pytest.fixture
def emitter() -> Emitter:
bus: Queue = Queue()
return Emitter(bus)
def emitter_queue() -> queue.Queue:
return queue.Queue()
@pytest.fixture
def emitter(emitter_queue: queue.Queue) -> Emitter:
return Emitter(merged_queue=emitter_queue)
@pytest.fixture
@@ -53,24 +57,27 @@ class TestMemoryToolEmitStart:
def test_emit_start_emits_memory_tool_start_packet(
self,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
) -> None:
memory_tool.emit_start(placement)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolStart)
assert packet.placement == placement
assert packet.placement is not None
assert packet.placement.turn_index == placement.turn_index
assert packet.placement.tab_index == placement.tab_index
assert packet.placement.model_index == 0 # emitter stamps model_index=0
def test_emit_start_with_different_placement(
self,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
) -> None:
placement = Placement(turn_index=2, tab_index=1)
memory_tool.emit_start(placement)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert packet.placement.turn_index == 2
assert packet.placement.tab_index == 1
@@ -81,7 +88,7 @@ class TestMemoryToolRun:
self,
mock_process: MagicMock,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
override_kwargs: MemoryToolOverrideKwargs,
) -> None:
@@ -93,21 +100,19 @@ class TestMemoryToolRun:
memory="User prefers Python",
)
# The delta packet should be in the queue
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolDelta)
assert packet.obj.memory_text == "User prefers Python"
assert packet.obj.operation == "add"
assert packet.obj.memory_id is None
assert packet.obj.index is None
assert packet.placement == placement
@patch("onyx.tools.tool_implementations.memory.memory_tool.process_memory_update")
def test_run_emits_delta_for_update_operation(
self,
mock_process: MagicMock,
memory_tool: MemoryTool,
emitter: Emitter,
emitter_queue: queue.Queue,
placement: Placement,
override_kwargs: MemoryToolOverrideKwargs,
) -> None:
@@ -119,7 +124,7 @@ class TestMemoryToolRun:
memory="User prefers light mode",
)
packet = emitter.bus.get_nowait()
_key, packet = emitter_queue.get_nowait()
assert isinstance(packet.obj, MemoryToolDelta)
assert packet.obj.memory_text == "User prefers light mode"
assert packet.obj.operation == "update"

View File

@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.4.39
version: 0.4.38
appVersion: latest
annotations:
category: Productivity

File diff suppressed because it is too large Load Diff

View File

@@ -1,77 +0,0 @@
{{- if and .Values.monitoring.serviceMonitors.enabled .Values.vectorDB.enabled }}
{{- if gt (int .Values.celery_worker_monitoring.replicaCount) 0 }}
---
apiVersion: monitoring.coreos.com/v1
kind: ServiceMonitor
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-monitoring
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- with .Values.monitoring.serviceMonitors.labels }}
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
namespaceSelector:
matchNames:
- {{ .Release.Namespace }}
selector:
matchLabels:
app: {{ .Values.celery_worker_monitoring.deploymentLabels.app }}
metrics: "true"
endpoints:
- port: metrics
path: /metrics
interval: 30s
scrapeTimeout: 10s
{{- end }}
{{- if gt (int .Values.celery_worker_docfetching.replicaCount) 0 }}
---
apiVersion: monitoring.coreos.com/v1
kind: ServiceMonitor
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docfetching
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- with .Values.monitoring.serviceMonitors.labels }}
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
namespaceSelector:
matchNames:
- {{ .Release.Namespace }}
selector:
matchLabels:
app: {{ .Values.celery_worker_docfetching.deploymentLabels.app }}
metrics: "true"
endpoints:
- port: metrics
path: /metrics
interval: 30s
scrapeTimeout: 10s
{{- end }}
{{- if gt (int .Values.celery_worker_docprocessing.replicaCount) 0 }}
---
apiVersion: monitoring.coreos.com/v1
kind: ServiceMonitor
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docprocessing
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- with .Values.monitoring.serviceMonitors.labels }}
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
namespaceSelector:
matchNames:
- {{ .Release.Namespace }}
selector:
matchLabels:
app: {{ .Values.celery_worker_docprocessing.deploymentLabels.app }}
metrics: "true"
endpoints:
- port: metrics
path: /metrics
interval: 30s
scrapeTimeout: 10s
{{- end }}
{{- end }}

View File

@@ -1,15 +0,0 @@
{{- if .Values.monitoring.grafana.dashboards.enabled }}
---
apiVersion: v1
kind: ConfigMap
metadata:
name: {{ include "onyx.fullname" . }}-indexing-pipeline-dashboard
labels:
{{- include "onyx.labels" . | nindent 4 }}
grafana_dashboard: "1"
annotations:
grafana_folder: "Onyx"
data:
onyx-indexing-pipeline.json: |
{{- .Files.Get "dashboards/indexing-pipeline.json" | nindent 4 }}
{{- end }}

View File

@@ -256,20 +256,6 @@ tooling:
# -- Which client binary to call; change if your image uses a non-default path.
psqlBinary: psql
monitoring:
grafana:
dashboards:
# -- Set to true to deploy Grafana dashboard ConfigMaps for the Onyx indexing pipeline.
# Requires kube-prometheus-stack (or equivalent) with the Grafana sidecar enabled and watching this namespace.
# The sidecar must be configured with label selector: grafana_dashboard=1
enabled: false
serviceMonitors:
# -- Set to true to deploy ServiceMonitor resources for Celery worker metrics endpoints.
# Requires the Prometheus Operator CRDs (included in kube-prometheus-stack).
# Use `labels` to match your Prometheus CR's serviceMonitorSelector (e.g. release: onyx-monitoring).
enabled: false
labels: {}
serviceAccount:
# Specifies whether a service account should be created
create: false

View File

@@ -19,10 +19,6 @@ module "eks" {
cluster_endpoint_public_access_cidrs = var.cluster_endpoint_public_access_cidrs
enable_cluster_creator_admin_permissions = true
# Control plane logging
cluster_enabled_log_types = var.cluster_enabled_log_types
cloudwatch_log_group_retention_in_days = var.cloudwatch_log_group_retention_in_days
eks_managed_node_group_defaults = {
ami_type = "AL2023_x86_64_STANDARD"
}

View File

@@ -161,25 +161,3 @@ variable "rds_db_connect_arn" {
description = "Full rds-db:connect ARN to allow (required when enable_rds_iam_for_service_account is true)"
default = null
}
variable "cluster_enabled_log_types" {
type = list(string)
description = "EKS control plane log types to enable (valid: api, audit, authenticator, controllerManager, scheduler)"
default = ["api", "audit", "authenticator", "controllerManager", "scheduler"]
validation {
condition = alltrue([for t in var.cluster_enabled_log_types : contains(["api", "audit", "authenticator", "controllerManager", "scheduler"], t)])
error_message = "Each entry must be one of: api, audit, authenticator, controllerManager, scheduler."
}
}
variable "cloudwatch_log_group_retention_in_days" {
type = number
description = "Number of days to retain EKS control plane logs in CloudWatch (0 = never expire)"
default = 30
validation {
condition = contains([0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653], var.cloudwatch_log_group_retention_in_days)
error_message = "Must be a valid CloudWatch retention value (0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653)."
}
}

View File

@@ -54,9 +54,6 @@ module "postgres" {
password = var.postgres_password
tags = local.merged_tags
enable_rds_iam_auth = var.enable_iam_auth
backup_retention_period = var.postgres_backup_retention_period
backup_window = var.postgres_backup_window
}
module "s3" {
@@ -83,10 +80,6 @@ module "eks" {
public_cluster_enabled = var.public_cluster_enabled
private_cluster_enabled = var.private_cluster_enabled
cluster_endpoint_public_access_cidrs = var.cluster_endpoint_public_access_cidrs
# Control plane logging
cluster_enabled_log_types = var.eks_cluster_enabled_log_types
cloudwatch_log_group_retention_in_days = var.eks_cloudwatch_log_group_retention_in_days
}
module "waf" {

View File

@@ -250,34 +250,3 @@ variable "opensearch_subnet_ids" {
description = "Subnet IDs for OpenSearch. If empty, uses first 3 private subnets."
default = []
}
# RDS Backup Configuration
variable "postgres_backup_retention_period" {
type = number
description = "Number of days to retain automated RDS backups (0 to disable)"
default = 7
}
variable "postgres_backup_window" {
type = string
description = "Preferred UTC time window for automated RDS backups (hh24:mi-hh24:mi)"
default = "03:00-04:00"
}
# EKS Control Plane Logging
variable "eks_cluster_enabled_log_types" {
type = list(string)
description = "EKS control plane log types to enable (valid: api, audit, authenticator, controllerManager, scheduler)"
default = ["api", "audit", "authenticator", "controllerManager", "scheduler"]
}
variable "eks_cloudwatch_log_group_retention_in_days" {
type = number
description = "Number of days to retain EKS control plane logs in CloudWatch (0 = never expire)"
default = 30
validation {
condition = contains([0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653], var.eks_cloudwatch_log_group_retention_in_days)
error_message = "Must be a valid CloudWatch retention value (0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653)."
}
}

View File

@@ -44,56 +44,5 @@ resource "aws_db_instance" "this" {
publicly_accessible = false
deletion_protection = true
storage_encrypted = true
# Automated backups
backup_retention_period = var.backup_retention_period
backup_window = var.backup_window
tags = var.tags
}
# CloudWatch alarm for CPU utilization monitoring
resource "aws_cloudwatch_metric_alarm" "cpu_utilization" {
alarm_name = "${var.identifier}-cpu-utilization"
alarm_description = "RDS CPU utilization for ${var.identifier}"
comparison_operator = "GreaterThanThreshold"
evaluation_periods = var.cpu_alarm_evaluation_periods
metric_name = "CPUUtilization"
namespace = "AWS/RDS"
period = var.cpu_alarm_period
statistic = "Average"
threshold = var.cpu_alarm_threshold
treat_missing_data = "missing"
alarm_actions = var.alarm_actions
ok_actions = var.alarm_actions
dimensions = {
DBInstanceIdentifier = aws_db_instance.this.identifier
}
tags = var.tags
}
# CloudWatch alarm for freeable memory monitoring
resource "aws_cloudwatch_metric_alarm" "freeable_memory" {
alarm_name = "${var.identifier}-freeable-memory"
alarm_description = "RDS freeable memory for ${var.identifier}"
comparison_operator = "LessThanThreshold"
evaluation_periods = var.memory_alarm_evaluation_periods
metric_name = "FreeableMemory"
namespace = "AWS/RDS"
period = var.memory_alarm_period
statistic = "Average"
threshold = var.memory_alarm_threshold
treat_missing_data = "missing"
alarm_actions = var.alarm_actions
ok_actions = var.alarm_actions
dimensions = {
DBInstanceIdentifier = aws_db_instance.this.identifier
}
tags = var.tags
tags = var.tags
}

View File

@@ -67,98 +67,3 @@ variable "enable_rds_iam_auth" {
description = "Enable AWS IAM database authentication for this RDS instance"
default = false
}
variable "backup_retention_period" {
type = number
description = "Number of days to retain automated backups (0 to disable)"
default = 7
validation {
condition = var.backup_retention_period >= 0 && var.backup_retention_period <= 35
error_message = "backup_retention_period must be between 0 and 35 (AWS RDS limit)."
}
}
variable "backup_window" {
type = string
description = "Preferred UTC time window for automated backups (hh24:mi-hh24:mi)"
default = "03:00-04:00"
validation {
condition = can(regex("^([01]\\d|2[0-3]):[0-5]\\d-([01]\\d|2[0-3]):[0-5]\\d$", var.backup_window))
error_message = "backup_window must be in hh24:mi-hh24:mi format (e.g. \"03:00-04:00\")."
}
}
# CloudWatch CPU alarm configuration
variable "cpu_alarm_threshold" {
type = number
description = "CPU utilization percentage threshold for the CloudWatch alarm"
default = 80
validation {
condition = var.cpu_alarm_threshold >= 0 && var.cpu_alarm_threshold <= 100
error_message = "cpu_alarm_threshold must be between 0 and 100 (percentage)."
}
}
variable "cpu_alarm_evaluation_periods" {
type = number
description = "Number of consecutive periods the threshold must be breached before alarming"
default = 3
validation {
condition = var.cpu_alarm_evaluation_periods >= 1
error_message = "cpu_alarm_evaluation_periods must be at least 1."
}
}
variable "cpu_alarm_period" {
type = number
description = "Period in seconds over which the CPU metric is evaluated"
default = 300
validation {
condition = var.cpu_alarm_period >= 60 && var.cpu_alarm_period % 60 == 0
error_message = "cpu_alarm_period must be a multiple of 60 seconds and at least 60 (CloudWatch requirement)."
}
}
variable "memory_alarm_threshold" {
type = number
description = "Freeable memory threshold in bytes. Alarm fires when memory drops below this value."
default = 256000000 # 256 MB
validation {
condition = var.memory_alarm_threshold > 0
error_message = "memory_alarm_threshold must be greater than 0."
}
}
variable "memory_alarm_evaluation_periods" {
type = number
description = "Number of consecutive periods the threshold must be breached before alarming"
default = 3
validation {
condition = var.memory_alarm_evaluation_periods >= 1
error_message = "memory_alarm_evaluation_periods must be at least 1."
}
}
variable "memory_alarm_period" {
type = number
description = "Period in seconds over which the freeable memory metric is evaluated"
default = 300
validation {
condition = var.memory_alarm_period >= 60 && var.memory_alarm_period % 60 == 0
error_message = "memory_alarm_period must be a multiple of 60 seconds and at least 60 (CloudWatch requirement)."
}
}
variable "alarm_actions" {
type = list(string)
description = "List of ARNs to notify when the alarm transitions state (e.g. SNS topic ARNs)"
default = []
}

View File

@@ -1,349 +0,0 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": { "type": "grafana", "uid": "-- Grafana --" },
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 1,
"id": null,
"links": [],
"liveNow": true,
"panels": [
{
"title": "Client-Side Search Latency (P50 / P95 / P99)",
"description": "End-to-end latency as measured by the Python client, including network round-trip and serialization overhead.",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 0 },
"id": 1,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "seconds",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "none" },
"thresholdsStyle": { "mode": "dashed" }
},
"thresholds": {
"mode": "absolute",
"steps": [
{ "color": "green", "value": null },
{ "color": "yellow", "value": 0.5 },
{ "color": "red", "value": 2.0 }
]
},
"unit": "s",
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
"legendFormat": "P50",
"refId": "A"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.95, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
"legendFormat": "P95",
"refId": "B"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.99, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
"legendFormat": "P99",
"refId": "C"
}
]
},
{
"title": "Server-Side Search Latency (P50 / P95 / P99)",
"description": "OpenSearch server-side execution time from the 'took' field in the response. Does not include network or client-side overhead.",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 0 },
"id": 2,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "seconds",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "none" },
"thresholdsStyle": { "mode": "dashed" }
},
"thresholds": {
"mode": "absolute",
"steps": [
{ "color": "green", "value": null },
{ "color": "yellow", "value": 0.5 },
{ "color": "red", "value": 2.0 }
]
},
"unit": "s",
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
"legendFormat": "P50",
"refId": "A"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.95, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
"legendFormat": "P95",
"refId": "B"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.99, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
"legendFormat": "P99",
"refId": "C"
}
]
},
{
"title": "Client-Side Latency by Search Type (P95)",
"description": "P95 client-side latency broken down by search type (hybrid, keyword, semantic, random, doc_id_retrieval).",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 10 },
"id": 3,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "seconds",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "none" },
"thresholdsStyle": { "mode": "off" }
},
"unit": "s",
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.95, sum by (search_type, le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
"legendFormat": "{{ search_type }}",
"refId": "A"
}
]
},
{
"title": "Search Throughput by Type",
"description": "Searches per second broken down by search type.",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 10 },
"id": 4,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "searches/s",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "normal" },
"thresholdsStyle": { "mode": "off" }
},
"unit": "ops",
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum by (search_type) (rate(onyx_opensearch_search_total[5m]))",
"legendFormat": "{{ search_type }}",
"refId": "A"
}
]
},
{
"title": "Concurrent Searches In Progress",
"description": "Number of OpenSearch searches currently in flight, broken down by search type. Summed across all instances.",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 20 },
"id": 5,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "searches",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "normal" },
"thresholdsStyle": { "mode": "off" }
},
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "sum by (search_type) (onyx_opensearch_searches_in_progress)",
"legendFormat": "{{ search_type }}",
"refId": "A"
}
]
},
{
"title": "Client vs Server Latency Overhead (P50)",
"description": "Difference between client-side and server-side P50 latency. Reveals network, serialization, and untracked OpenSearch overhead.",
"type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 20 },
"id": 6,
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"fieldConfig": {
"defaults": {
"color": { "mode": "palette-classic" },
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisLabel": "seconds",
"axisPlacement": "auto",
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": { "type": "linear" },
"showPoints": "never",
"spanNulls": false,
"stacking": { "group": "A", "mode": "none" },
"thresholdsStyle": { "mode": "off" }
},
"unit": "s",
"min": 0
},
"overrides": []
},
"targets": [
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m]))) - histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
"legendFormat": "Client - Server overhead (P50)",
"refId": "A"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
"legendFormat": "Client P50",
"refId": "B"
},
{
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
"legendFormat": "Server P50",
"refId": "C"
}
]
}
],
"refresh": "5s",
"schemaVersion": 37,
"style": "dark",
"tags": ["onyx", "opensearch", "search", "latency"],
"templating": {
"list": [
{
"current": {
"text": "Prometheus",
"value": "prometheus"
},
"includeAll": false,
"name": "DS_PROMETHEUS",
"options": [],
"query": "prometheus",
"refresh": 1,
"type": "datasource"
}
]
},
"time": { "from": "now-60m", "to": "now" },
"timepicker": {
"refresh_intervals": ["5s", "10s", "30s", "1m"]
},
"timezone": "",
"title": "Onyx OpenSearch Search Latency",
"uid": "onyx-opensearch-search-latency",
"version": 0,
"weekStart": ""
}

View File

@@ -73,17 +73,11 @@ ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
ARG SENTRY_RELEASE
ENV SENTRY_RELEASE=${SENTRY_RELEASE}
# Add NODE_OPTIONS argument
ARG NODE_OPTIONS
# SENTRY_AUTH_TOKEN is injected via BuildKit secret mount so it is never written
# to any image layer, build cache, or registry manifest.
# Use NODE_OPTIONS in the build command
RUN --mount=type=secret,id=sentry_auth_token,env=SENTRY_AUTH_TOKEN \
NODE_OPTIONS="${NODE_OPTIONS}" npx next build
RUN NODE_OPTIONS="${NODE_OPTIONS}" npx next build
# Step 2. Production image, copy all the files and run next
FROM base AS runner
@@ -156,9 +150,6 @@ ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
ARG SENTRY_RELEASE
ENV SENTRY_RELEASE=${SENTRY_RELEASE}
# Default ONYX_VERSION, typically overridden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION}

View File

@@ -24,7 +24,6 @@ type TextFont =
| "secondary-body"
| "secondary-action"
| "secondary-mono"
| "secondary-mono-label"
| "figure-small-label"
| "figure-small-value"
| "figure-keystroke";
@@ -89,7 +88,6 @@ const FONT_CONFIG: Record<TextFont, string> = {
"secondary-body": "font-secondary-body",
"secondary-action": "font-secondary-action",
"secondary-mono": "font-secondary-mono",
"secondary-mono-label": "font-secondary-mono-label",
"figure-small-label": "font-figure-small-label",
"figure-small-value": "font-figure-small-value",
"figure-keystroke": "font-figure-keystroke",

View File

@@ -8,7 +8,6 @@ import * as Sentry from "@sentry/nextjs";
if (process.env.NEXT_PUBLIC_SENTRY_DSN) {
Sentry.init({
dsn: process.env.NEXT_PUBLIC_SENTRY_DSN,
release: process.env.SENTRY_RELEASE,
// Only capture unhandled exceptions
tracesSampleRate: 0,
debug: false,

View File

@@ -7,7 +7,6 @@ import * as Sentry from "@sentry/nextjs";
if (process.env.NEXT_PUBLIC_SENTRY_DSN) {
Sentry.init({
dsn: process.env.NEXT_PUBLIC_SENTRY_DSN,
release: process.env.SENTRY_RELEASE,
// Setting this option to true will print useful information to the console while you're setting up Sentry.
debug: false,

View File

@@ -1 +1,7 @@
export { default } from "@/refresh-pages/admin/CodeInterpreterPage";
"use client";
import CodeInterpreterPage from "@/refresh-pages/admin/CodeInterpreterPage";
// Route entry point for this client component file: renders the shared
// CodeInterpreterPage imported from refresh-pages.
export default function Page() {
return <CodeInterpreterPage />;
}

View File

@@ -0,0 +1,23 @@
"use client";
import { ModalCreationInterface } from "@/refresh-components/contexts/ModalContext";
import { ImageProvider } from "@/app/admin/configuration/image-generation/constants";
import { LLMProviderView } from "@/interfaces/llm";
import { ImageGenerationConfigView } from "@/lib/configuration/imageConfigurationService";
import { getImageGenForm } from "./forms";
/** Props for ImageGenerationConnectionModal; the whole object is handed to getImageGenForm. */
interface Props {
// Modal handle typed by ModalContext's ModalCreationInterface.
modal: ModalCreationInterface;
// Image provider the form is built for.
imageProvider: ImageProvider;
// LLM provider views made available to the form.
existingProviders: LLMProviderView[];
// Optional; presumably set when editing an existing configuration — confirm against caller.
existingConfig?: ImageGenerationConfigView;
// Zero-argument callback; presumably invoked after a successful save — confirm in the form wrapper.
onSuccess: () => void;
}
/**
 * Modal for creating/editing image generation configurations.
 *
 * Passes its props unchanged to getImageGenForm and renders the result inside
 * a fragment. The provider-specific form is chosen from
 * imageProvider.provider_name; that selection happens in getImageGenForm,
 * not in this component.
 */
export default function ImageGenerationConnectionModal(props: Props) {
return <>{getImageGenForm(props)}</>;
}

View File

@@ -2,6 +2,7 @@
import { useState, useMemo, useEffect } from "react";
import useSWR from "swr";
import { Select } from "@/refresh-components/cards";
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
@@ -10,39 +11,24 @@ import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
import {
IMAGE_PROVIDER_GROUPS,
ImageProvider,
} from "@/refresh-pages/admin/ImageGenerationPage/constants";
} from "@/app/admin/configuration/image-generation/constants";
import ImageGenerationConnectionModal from "@/app/admin/configuration/image-generation/ImageGenerationConnectionModal";
import {
ImageGenerationConfigView,
setDefaultImageGenerationConfig,
unsetDefaultImageGenerationConfig,
deleteImageGenerationConfig,
} from "@/refresh-pages/admin/ImageGenerationPage/svc";
} from "@/lib/configuration/imageConfigurationService";
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
import Message from "@/refresh-components/messages/Message";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { Button, SelectCard, Text } from "@opal/components";
import { Content, CardHeaderLayout } from "@opal/layouts";
import { Hoverable } from "@opal/core";
import {
SvgArrowExchange,
SvgArrowRightCircle,
SvgCheckSquare,
SvgSettings,
SvgSlash,
SvgUnplug,
} from "@opal/icons";
import { Button, Text } from "@opal/components";
import { SvgSlash, SvgUnplug } from "@opal/icons";
import { markdown } from "@opal/utils";
import { getImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms";
const NO_DEFAULT_VALUE = "__none__";
const STATUS_TO_STATE = {
disconnected: "empty",
connected: "filled",
selected: "selected",
} as const;
export default function ImageGenerationContent() {
const {
data: llmProviderResponse,
@@ -212,13 +198,16 @@ export default function ImageGenerationContent() {
return (
<>
<div className="flex flex-col gap-4">
<Content
title="Image Generation Model"
description="Select a model to generate images in chat."
sizePreset="main-content"
variant="section"
/>
<div className="flex flex-col gap-6">
{/* Section Header */}
<div className="flex flex-col gap-0.5">
<Text font="main-content-emphasis" color="text-05">
Image Generation Model
</Text>
<Text font="secondary-body" color="text-03">
Select a model to generate images in chat.
</Text>
</div>
{connectedProviderIds.size === 0 && (
<Message
@@ -234,111 +223,32 @@ export default function ImageGenerationContent() {
{/* Provider Groups */}
{IMAGE_PROVIDER_GROUPS.map((group) => (
<div key={group.name} className="flex flex-col gap-2">
<Content title={group.name} sizePreset="secondary" variant="body" />
{group.providers.map((provider) => {
const status = getStatus(provider);
const isDisconnected = status === "disconnected";
const isConnected = status === "connected";
const isSelected = status === "selected";
return (
<Hoverable.Root
<Text font="secondary-body" color="text-03">
{group.name}
</Text>
<div className="flex flex-col gap-2">
{group.providers.map((provider) => (
<Select
key={provider.image_provider_id}
group="image-gen/ProviderCard"
>
<SelectCard
variant="select-card"
state={STATUS_TO_STATE[status]}
sizeVariant="lg"
aria-label={`image-gen-provider-${provider.image_provider_id}`}
onClick={
isDisconnected
? () => handleConnect(provider)
: isSelected
? () => handleDeselect(provider)
: undefined
}
>
<CardHeaderLayout
sizePreset="main-ui"
variant="section"
icon={() => (
<ProviderIcon
provider={provider.provider_name}
size={16}
/>
)}
title={provider.title}
description={provider.description}
rightChildren={
isDisconnected ? (
<Button
prominence="tertiary"
rightIcon={SvgArrowExchange}
onClick={(e) => {
e.stopPropagation();
handleConnect(provider);
}}
>
Connect
</Button>
) : isConnected ? (
<Button
prominence="tertiary"
rightIcon={SvgArrowRightCircle}
onClick={(e) => {
e.stopPropagation();
handleSelect(provider);
}}
>
Set as Default
</Button>
) : isSelected ? (
<div className="p-2">
<Content
title="Current Default"
sizePreset="main-ui"
variant="section"
icon={SvgCheckSquare}
/>
</div>
) : undefined
}
bottomRightChildren={
!isDisconnected ? (
<div className="flex flex-row px-1 pb-1">
<Hoverable.Item group="image-gen/ProviderCard">
<Button
icon={SvgUnplug}
tooltip="Disconnect"
aria-label={`Disconnect ${provider.title}`}
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
setDisconnectProvider(provider);
}}
size="md"
/>
</Hoverable.Item>
<Button
icon={SvgSettings}
tooltip="Edit"
aria-label={`Edit ${provider.title}`}
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
handleEdit(provider);
}}
size="md"
/>
</div>
) : undefined
}
/>
</SelectCard>
</Hoverable.Root>
);
})}
aria-label={`image-gen-provider-${provider.image_provider_id}`}
icon={() => (
<ProviderIcon provider={provider.provider_name} size={18} />
)}
title={provider.title}
description={provider.description}
status={getStatus(provider)}
onConnect={() => handleConnect(provider)}
onSelect={() => handleSelect(provider)}
onDeselect={() => handleDeselect(provider)}
onEdit={() => handleEdit(provider)}
onDisconnect={
getStatus(provider) !== "disconnected"
? () => setDisconnectProvider(provider)
: undefined
}
/>
))}
</div>
</div>
))}
</div>
@@ -447,13 +357,13 @@ export default function ImageGenerationContent() {
{activeProvider && (
<modal.Provider>
{getImageGenForm({
modal: modal,
imageProvider: activeProvider,
existingProviders: llmProviders,
existingConfig: editConfig || undefined,
onSuccess: handleModalSuccess,
})}
<ImageGenerationConnectionModal
modal={modal}
imageProvider={activeProvider}
existingProviders={llmProviders}
existingConfig={editConfig || undefined}
onSuccess={handleModalSuccess}
/>
</modal.Provider>
)}
</>

View File

@@ -7,14 +7,14 @@ import { FormField } from "@/refresh-components/form/FormField";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
import { ImageGenFormWrapper } from "@/refresh-pages/admin/ImageGenerationPage/forms/ImageGenFormWrapper";
import { ImageGenFormWrapper } from "./ImageGenFormWrapper";
import {
ImageGenFormBaseProps,
ImageGenFormChildProps,
ImageGenSubmitPayload,
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
import { ImageGenerationCredentials } from "@/refresh-pages/admin/ImageGenerationPage/svc";
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
} from "./types";
import { ImageGenerationCredentials } from "@/lib/configuration/imageConfigurationService";
import { ImageProvider } from "../constants";
import {
parseAzureTargetUri,
isValidAzureTargetUri,

View File

@@ -10,14 +10,14 @@ import {
createImageGenerationConfig,
updateImageGenerationConfig,
fetchImageGenerationCredentials,
} from "@/refresh-pages/admin/ImageGenerationPage/svc";
} from "@/lib/configuration/imageConfigurationService";
import { APIFormFieldState } from "@/refresh-components/form/types";
import {
ImageGenFormWrapperProps,
ImageGenFormChildProps,
ImageGenSubmitPayload,
FormValues,
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
} from "./types";
import { toast } from "@/hooks/useToast";
export function ImageGenFormWrapper<T extends FormValues>({
@@ -41,6 +41,9 @@ export function ImageGenFormWrapper<T extends FormValues>({
const [errorMessage, setErrorMessage] = useState("");
const [isLoadingCredentials, setIsLoadingCredentials] = useState(false);
// Form reset key for re-initialization
const [formResetKey, setFormResetKey] = useState(0);
// Track merged initial values with fetched credentials
const [mergedInitialValues, setMergedInitialValues] =
useState<T>(initialValues);
@@ -70,6 +73,7 @@ export function ImageGenFormWrapper<T extends FormValues>({
imageProvider
);
setMergedInitialValues((prev) => ({ ...prev, ...credValues }));
setFormResetKey((k) => k + 1);
}
})
.catch((err) => {
@@ -272,6 +276,7 @@ export function ImageGenFormWrapper<T extends FormValues>({
return (
<Formik<T>
key={formResetKey}
initialValues={mergedInitialValues}
onSubmit={handleSubmit}
validationSchema={validationSchema}

View File

@@ -6,14 +6,14 @@ import { FormikField } from "@/refresh-components/form/FormikField";
import { FormField } from "@/refresh-components/form/FormField";
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
import { ImageGenFormWrapper } from "@/refresh-pages/admin/ImageGenerationPage/forms/ImageGenFormWrapper";
import { ImageGenFormWrapper } from "./ImageGenFormWrapper";
import {
ImageGenFormBaseProps,
ImageGenFormChildProps,
ImageGenSubmitPayload,
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
import { ImageGenerationCredentials } from "@/refresh-pages/admin/ImageGenerationPage/svc";
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
} from "./types";
import { ImageGenerationCredentials } from "@/lib/configuration/imageConfigurationService";
import { ImageProvider } from "../constants";
// OpenAI form values - just API key
interface OpenAIFormValues {

View File

@@ -6,14 +6,14 @@ import { FormField } from "@/refresh-components/form/FormField";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputFile from "@/refresh-components/inputs/InputFile";
import InlineExternalLink from "@/refresh-components/InlineExternalLink";
import { ImageGenFormWrapper } from "@/refresh-pages/admin/ImageGenerationPage/forms/ImageGenFormWrapper";
import { ImageGenFormWrapper } from "./ImageGenFormWrapper";
import {
ImageGenFormBaseProps,
ImageGenFormChildProps,
ImageGenSubmitPayload,
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
import { ImageGenerationCredentials } from "@/refresh-pages/admin/ImageGenerationPage/svc";
} from "./types";
import { ImageProvider } from "../constants";
import { ImageGenerationCredentials } from "@/lib/configuration/imageConfigurationService";
const VERTEXAI_PROVIDER_NAME = "vertex_ai";
const VERTEXAI_DEFAULT_LOCATION = "global";

View File

@@ -1,8 +1,8 @@
import React from "react";
import { ImageGenFormBaseProps } from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
import { OpenAIImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/OpenAIImageGenForm";
import { AzureImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/AzureImageGenForm";
import { VertexImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/VertexImageGenForm";
import { ImageGenFormBaseProps } from "./types";
import { OpenAIImageGenForm } from "./OpenAIImageGenForm";
import { AzureImageGenForm } from "./AzureImageGenForm";
import { VertexImageGenForm } from "./VertexImageGenForm";
/**
* Factory function that routes to the correct provider-specific form

View File

@@ -0,0 +1,5 @@
// Barrel for the image-generation forms: shared types, the common form wrapper,
// provider-specific forms, and the getImageGenForm factory.
// NOTE(review): VertexImageGenForm is not re-exported here, unlike the OpenAI
// and Azure forms — confirm whether that omission is intentional.
export * from "./types";
export { ImageGenFormWrapper } from "./ImageGenFormWrapper";
export { OpenAIImageGenForm } from "./OpenAIImageGenForm";
export { AzureImageGenForm } from "./AzureImageGenForm";
export { getImageGenForm } from "./getImageGenForm";

View File

@@ -1,10 +1,10 @@
import { FormikProps } from "formik";
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
import { ImageProvider } from "../constants";
import { LLMProviderView } from "@/interfaces/llm";
import {
ImageGenerationConfigView,
ImageGenerationCredentials,
} from "@/refresh-pages/admin/ImageGenerationPage/svc";
} from "@/lib/configuration/imageConfigurationService";
import { ModalCreationInterface } from "@/refresh-components/contexts/ModalContext";
import { APIFormFieldState } from "@/refresh-components/form/types";

View File

@@ -1 +1,22 @@
export { default } from "@/refresh-pages/admin/ImageGenerationPage";
"use client";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import ImageGenerationContent from "./ImageGenerationContent";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
const route = ADMIN_ROUTES.IMAGE_GENERATION;
// Admin page for image generation settings. Wraps ImageGenerationContent in the
// settings layout; the header icon and title come from the module-level `route`
// (ADMIN_ROUTES.IMAGE_GENERATION), while the description is hard-coded here.
export default function Page() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={route.icon}
title={route.title}
description="Settings for in-chat image generation."
/>
<SettingsLayouts.Body>
<ImageGenerationContent />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1 +1,7 @@
export { default } from "@/refresh-pages/admin/LLMConfigurationPage";
"use client";
import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
// Route entry point for this client component file: renders the shared
// LLMConfigurationPage imported from refresh-pages.
export default function Page() {
return <LLMConfigurationPage />;
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,6 @@
"use client";
import { Form, Formik } from "formik";
import { mutate } from "swr";
import { SWR_KEYS } from "@/lib/swr-keys";
import * as Yup from "yup";
import { toast } from "@/hooks/useToast";
import {
@@ -121,10 +119,6 @@ export const DocumentSetCreationForm = ({
? "Successfully updated document set!"
: "Successfully created document set!"
);
await Promise.all([
mutate(SWR_KEYS.documentSets),
mutate(SWR_KEYS.documentSetsEditable),
]);
onClose();
} else {
const errorMsg = await response.text();

View File

@@ -1,16 +1,17 @@
import { errorHandlingFetcher } from "@/lib/fetcher";
import { DocumentSetSummary } from "@/lib/types";
import useSWR, { mutate } from "swr";
import { SWR_KEYS } from "@/lib/swr-keys";
const DOCUMENT_SETS_URL = "/api/manage/document-set";
const GET_EDITABLE_DOCUMENT_SETS_URL =
"/api/manage/document-set?get_editable=true";
export function refreshDocumentSets() {
mutate(SWR_KEYS.documentSets);
mutate(DOCUMENT_SETS_URL);
}
export function useDocumentSets(getEditable: boolean = false) {
const url = getEditable
? SWR_KEYS.documentSetsEditable
: SWR_KEYS.documentSets;
const url = getEditable ? GET_EDITABLE_DOCUMENT_SETS_URL : DOCUMENT_SETS_URL;
const swrResponse = useSWR<DocumentSetSummary[]>(url, errorHandlingFetcher, {
refreshInterval: 5000, // 5 seconds

View File

@@ -182,7 +182,8 @@ export async function* sendMessage({
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
const data = await response.json().catch(() => ({}));
throw new Error(data.detail ?? `HTTP error! status: ${response.status}`);
}
yield* handleSSEStream<PacketType>(response, signal);

View File

@@ -11,7 +11,6 @@ import Text from "@/refresh-components/texts/Text";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { ThreeDotsLoader } from "@/components/Loading";
import { ChatSessionMinimal } from "@/app/ee/admin/performance/usage/types";
import { Section } from "@/layouts/general-layouts";
import { timestampToReadableDate } from "@/lib/dateUtils";
import { Dispatch, SetStateAction, useCallback, useState } from "react";
import { Feedback, TaskStatus } from "@/lib/types";
@@ -102,32 +101,34 @@ function SelectFeedbackType({
onValueChange: (value: Feedback | "all") => void;
}) {
return (
<Section alignItems="start" gap={0.25}>
<Text as="p" className="font-medium">
<div>
<Text as="p" className="my-auto mr-2 font-medium mb-1">
Feedback Type
</Text>
<InputSelect
value={value}
onValueChange={onValueChange as (value: string) => void}
>
<InputSelect.Trigger />
<div className="max-w-sm space-y-6">
<InputSelect
value={value}
onValueChange={onValueChange as (value: string) => void}
>
<InputSelect.Trigger />
<InputSelect.Content>
<InputSelect.Item value="all" icon={SvgMinusCircle}>
Any
</InputSelect.Item>
<InputSelect.Item value="like" icon={SvgThumbsUp}>
Like
</InputSelect.Item>
<InputSelect.Item value="dislike" icon={SvgThumbsDown}>
Dislike
</InputSelect.Item>
<InputSelect.Item value="mixed" icon={SvgMinus}>
Mixed
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</Section>
<InputSelect.Content>
<InputSelect.Item value="all" icon={SvgMinusCircle}>
Any
</InputSelect.Item>
<InputSelect.Item value="like" icon={SvgThumbsUp}>
Like
</InputSelect.Item>
<InputSelect.Item value="dislike" icon={SvgThumbsDown}>
Dislike
</InputSelect.Item>
<InputSelect.Item value="mixed" icon={SvgMinus}>
Mixed
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</div>
</div>
);
}
@@ -184,61 +185,60 @@ function PreviousQueryHistoryExportsModal({
onClose={() => setShowModal(false)}
/>
<Modal.Body>
<Table>
<TableHeader>
<TableRow>
<TableHead>Generated At</TableHead>
<TableHead>Start Range</TableHead>
<TableHead>End Range</TableHead>
<TableHead>Status</TableHead>
<TableHead>Download</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{paginatedTasks.map((task, index) => (
<TableRow key={index}>
<TableCell>
{humanReadableFormatWithTime(task.startTime)}
</TableCell>
<TableCell>{task.start.toDateString()}</TableCell>
<TableCell>{task.end.toDateString()}</TableCell>
<TableCell>
<ExportBadge status={task.status} />
</TableCell>
<TableCell>
<Button
variant="default"
prominence="tertiary"
icon={SvgDownloadCloud}
size="sm"
disabled={task.status !== "SUCCESS"}
tooltip={
task.status !== "SUCCESS"
? "Export is not yet ready"
: undefined
}
href={
task.status === "SUCCESS"
? withRequestId(
<div className="flex flex-col w-full">
<div className="flex flex-1">
<Table>
<TableHeader>
<TableRow>
<TableHead>Generated At</TableHead>
<TableHead>Start Range</TableHead>
<TableHead>End Range</TableHead>
<TableHead>Status</TableHead>
<TableHead>Download</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{paginatedTasks.map((task, index) => (
<TableRow key={index}>
<TableCell>
{humanReadableFormatWithTime(task.startTime)}
</TableCell>
<TableCell>{task.start.toDateString()}</TableCell>
<TableCell>{task.end.toDateString()}</TableCell>
<TableCell>
<ExportBadge status={task.status} />
</TableCell>
<TableCell>
{task.status === "SUCCESS" ? (
<a
className="flex justify-center"
href={withRequestId(
DOWNLOAD_QUERY_HISTORY_URL,
task.taskId
)
: undefined
}
/>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
)}
>
<SvgDownloadCloud className="h-4 w-4 text-action-link-05" />
</a>
) : (
<SvgDownloadCloud className="h-4 w-4 text-action-link-05 opacity-20" />
)}
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</div>
<Section>
<PageSelector
currentPage={taskPage}
totalPages={totalTaskPages}
onPageChange={setTaskPage}
/>
</Section>
<div className="flex mt-3">
<div className="mx-auto">
<PageSelector
currentPage={taskPage}
totalPages={totalTaskPages}
onPageChange={setTaskPage}
/>
</div>
</div>
</div>
</Modal.Body>
</Modal.Content>
</Modal>
@@ -330,48 +330,48 @@ export function QueryHistoryTable() {
</div>
</div>
<Separator />
<Section>
<Table className="mt-5">
<TableHeader>
<Table className="mt-5">
<TableHeader>
<TableRow>
<TableHead>First User Message</TableHead>
<TableHead>First AI Response</TableHead>
<TableHead>Feedback</TableHead>
<TableHead>User</TableHead>
<TableHead>Persona</TableHead>
<TableHead>Date</TableHead>
</TableRow>
</TableHeader>
{isLoading ? (
<TableBody>
<TableRow>
<TableHead>First User Message</TableHead>
<TableHead>First AI Response</TableHead>
<TableHead>Feedback</TableHead>
<TableHead>User</TableHead>
<TableHead>Persona</TableHead>
<TableHead>Date</TableHead>
<TableCell colSpan={6} className="text-center">
<ThreeDotsLoader />
</TableCell>
</TableRow>
</TableHeader>
{isLoading ? (
<TableBody>
<TableRow>
<TableCell colSpan={6} className="text-center">
<ThreeDotsLoader />
</TableCell>
</TableRow>
</TableBody>
) : (
<TableBody>
{chatSessionData?.map((chatSessionMinimal) => (
<QueryHistoryTableRow
key={chatSessionMinimal.id}
chatSessionMinimal={chatSessionMinimal}
/>
))}
</TableBody>
)}
</Table>
</TableBody>
) : (
<TableBody>
{chatSessionData?.map((chatSessionMinimal) => (
<QueryHistoryTableRow
key={chatSessionMinimal.id}
chatSessionMinimal={chatSessionMinimal}
/>
))}
</TableBody>
)}
</Table>
{chatSessionData && (
<Section>
{chatSessionData && (
<div className="mt-3 flex">
<div className="mx-auto">
<PageSelector
totalPages={totalPages}
currentPage={currentPage}
onPageChange={goToPage}
/>
</Section>
)}
</Section>
</div>
</div>
)}
</CardSection>
{showModal && (

View File

@@ -330,14 +330,6 @@
letter-spacing: 0px;
}
.font-secondary-mono-label {
font-family: var(--font-dm-mono);
font-size: 12px;
font-weight: 500;
line-height: 16px;
letter-spacing: 0px;
}
/* FIGURE STYLES */
.font-figure-small-label {

View File

@@ -1,7 +1,6 @@
"use client";
import useSWR from "swr";
import { SWR_KEYS } from "@/lib/swr-keys";
import {
UserSpecificAgentPreference,
UserSpecificAgentPreferences,
@@ -10,6 +9,7 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
import { useCallback } from "react";
// TODO: rename to agent — https://linear.app/onyx-app/issue/ENG-3766
const AGENT_PREFERENCES_URL = "/api/user/assistant/preferences";
// TODO: rename to agent — https://linear.app/onyx-app/issue/ENG-3766
const buildUpdateAgentPreferenceUrl = (agentId: number) =>
@@ -21,11 +21,10 @@ const buildUpdateAgentPreferenceUrl = (agentId: number) =>
*/
export default function useAgentPreferences() {
const { data, mutate } = useSWR<UserSpecificAgentPreferences>(
SWR_KEYS.agentPreferences,
AGENT_PREFERENCES_URL,
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
}
);

View File

@@ -2,7 +2,6 @@
import useSWR from "swr";
import { useState, useEffect, useMemo, useCallback } from "react";
import { SWR_KEYS } from "@/lib/swr-keys";
import {
MinimalPersonaSnapshot,
FullPersona,
@@ -37,11 +36,10 @@ import useChatSessions from "./useChatSessions";
*/
export function useAgents() {
const { data, error, mutate } = useSWR<MinimalPersonaSnapshot[]>(
SWR_KEYS.personas,
"/api/persona",
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
}
);
@@ -78,11 +76,10 @@ export function useAgents() {
*/
export function useAgent(agentId: number | null) {
const { data, error, isLoading, mutate } = useSWR<FullPersona>(
agentId ? SWR_KEYS.persona(agentId) : null,
agentId ? `/api/persona/${agentId}` : null,
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
}
);

View File

@@ -1,6 +1,5 @@
import useSWR from "swr";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { SWR_KEYS } from "@/lib/swr-keys";
interface AuthTypeAPIResponse {
auth_type: string;
@@ -55,12 +54,11 @@ export function useAuthTypeMetadata(): {
error: Error | undefined;
} {
const { data, error, isLoading } = useSWR<AuthTypeMetadata>(
SWR_KEYS.authType,
"/api/auth/type",
fetchAuthTypeMetadata,
{
revalidateOnFocus: false,
revalidateOnReconnect: false,
revalidateIfStale: false,
dedupingInterval: 30_000,
}
);

View File

@@ -3,7 +3,6 @@
import useSWR from "swr";
import { ToolSnapshot } from "@/lib/tools/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
/**
* Hook to fetch all available tools from the backend.
@@ -25,12 +24,10 @@ import { SWR_KEYS } from "@/lib/swr-keys";
*/
export function useAvailableTools() {
const { data, error, mutate } = useSWR<ToolSnapshot[]>(
SWR_KEYS.tools,
"/api/tool",
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
revalidateOnFocus: true,
}
);

View File

@@ -2,7 +2,6 @@ import useSWR from "swr";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
import {
BillingInformation,
SubscriptionStatus,
@@ -17,15 +16,14 @@ import {
*/
export function useBillingInformation() {
const url = NEXT_PUBLIC_CLOUD_ENABLED
? SWR_KEYS.billingInformationCloud
: SWR_KEYS.billingInformationSelfHosted;
? "/api/tenants/billing-information"
: "/api/admin/billing/billing-information";
const { data, error, mutate, isLoading } = useSWR<
BillingInformation | SubscriptionStatus
>(url, errorHandlingFetcher, {
revalidateOnFocus: false,
revalidateOnReconnect: false,
revalidateIfStale: false,
dedupingInterval: 30000,
shouldRetryOnError: false,
keepPreviousData: true,

View File

@@ -901,6 +901,11 @@ export default function useChatController({
});
}
}
// Surface FIFO errors (e.g. 429 before any packets arrive) so the
// catch block replaces the thinking placeholder with an error message.
if (stack.error) {
throw new Error(stack.error);
}
} catch (e: any) {
console.log("Error:", e);
const errorMsg = e.message;

View File

@@ -10,7 +10,6 @@ import {
import useSWRInfinite from "swr/infinite";
import { ChatSession, ChatSessionSharedStatus } from "@/app/app/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
import useAppFocus from "./useAppFocus";
import { useAgents } from "./useAgents";
@@ -147,7 +146,7 @@ export default function useChatSessions(): UseChatSessionsOutput {
// First page — no cursor
if (pageIndex === 0) {
return `${SWR_KEYS.chatSessions}?page_size=${PAGE_SIZE}`;
return `/api/chat/get-user-chat-sessions?page_size=${PAGE_SIZE}`;
}
// Subsequent pages — cursor from the last session of the previous page
@@ -159,7 +158,7 @@ export default function useChatSessions(): UseChatSessionsOutput {
page_size: PAGE_SIZE.toString(),
before: lastSession.time_updated,
});
return `${SWR_KEYS.chatSessions}?${params.toString()}`;
return `/api/chat/get-user-chat-sessions?${params.toString()}`;
};
const { data, error, setSize, mutate } = useSWRInfinite<ChatSessionsResponse>(
@@ -167,7 +166,6 @@ export default function useChatSessions(): UseChatSessionsOutput {
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
revalidateFirstPage: true,
revalidateAll: false,
dedupingInterval: 30000,

View File

@@ -1,7 +1,6 @@
import useSWR, { type KeyedMutator } from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { User } from "@/lib/types";
import { SWR_KEYS } from "@/lib/swr-keys";
/**
* Fetches the current authenticated user via SWR (`/api/me`).
@@ -30,12 +29,11 @@ export function useCurrentUser(): {
userError: (Error & { status?: number }) | undefined;
} {
const { data, mutate, error } = useSWR<User>(
SWR_KEYS.me,
"/api/me",
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateOnReconnect: false,
revalidateIfStale: false,
dedupingInterval: 30_000,
}
);

View File

@@ -1,42 +0,0 @@
import useSWR from "swr";
import { fetchExecutionLogs } from "@/refresh-pages/admin/HooksPage/svc";
import type { HookExecutionRecord } from "@/refresh-pages/admin/HooksPage/interfaces";
const ONE_HOUR_MS = 60 * 60 * 1000;
const THIRTY_DAYS_MS = 30 * 24 * 60 * 60 * 1000;
/** Shape returned by useHookExecutionLogs. */
interface UseHookExecutionLogsResult {
// Loading flag reported by SWR for the logs request.
isLoading: boolean;
// Error reported by SWR for the logs request, if any.
error: Error | undefined;
// True when recentErrors is non-empty.
hasRecentErrors: boolean;
// Fetched records whose created_at is less than one hour before "now".
recentErrors: HookExecutionRecord[];
// Fetched records between one hour (inclusive) and thirty days (exclusive) old.
olderErrors: HookExecutionRecord[];
}
/**
 * Loads a hook's execution logs through SWR and buckets them by age.
 *
 * The SWR key is `["hook-execution-logs", hookId, limit]` and the request is
 * configured with a 60 000 ms `refreshInterval`. Ages are measured against
 * `Date.now()` taken once per call.
 *
 * @param hookId - Passed to `fetchExecutionLogs` and included in the SWR key.
 * @param limit - Passed to `fetchExecutionLogs`; defaults to 10.
 * @returns SWR loading/error state, the records younger than one hour
 *   (`recentErrors`), the records at least one hour but less than thirty days
 *   old (`olderErrors`), and `hasRecentErrors`. Records that are older, or
 *   whose `created_at` does not parse to a valid date, land in neither bucket.
 */
export function useHookExecutionLogs(
  hookId: number,
  limit = 10
): UseHookExecutionLogsResult {
  const swr = useSWR(
    ["hook-execution-logs", hookId, limit],
    () => fetchExecutionLogs(hookId, limit),
    { refreshInterval: 60_000 }
  );

  const renderedAt = Date.now();
  const recentErrors: HookExecutionRecord[] = [];
  const olderErrors: HookExecutionRecord[] = [];

  // Single pass: each record goes to at most one bucket. A NaN age fails both
  // comparisons and is skipped, matching the behaviour of two separate filters.
  for (const entry of swr.data ?? []) {
    const ageMs = renderedAt - new Date(entry.created_at).getTime();
    if (ageMs < ONE_HOUR_MS) {
      recentErrors.push(entry);
    } else if (ageMs < THIRTY_DAYS_MS) {
      olderErrors.push(entry);
    }
  }

  return {
    isLoading: swr.isLoading,
    error: swr.error,
    hasRecentErrors: recentErrors.length > 0,
    recentErrors,
    olderErrors,
  };
}

View File

@@ -2,13 +2,13 @@
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
import {
LLMProviderDescriptor,
LLMProviderResponse,
LLMProviderView,
WellKnownLLMProviderDescriptor,
} from "@/interfaces/llm";
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
/**
* Fetches configured LLM providers accessible to the current user.
@@ -45,14 +45,13 @@ import {
export function useLLMProviders(personaId?: number) {
const url =
personaId !== undefined
? SWR_KEYS.llmProvidersForPersona(personaId)
: SWR_KEYS.llmProviders;
? `/api/llm/persona/${personaId}/providers`
: "/api/llm/provider";
const { data, error, mutate } = useSWR<
LLMProviderResponse<LLMProviderDescriptor>
>(url, errorHandlingFetcher, {
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
});
@@ -89,11 +88,10 @@ export function useLLMProviders(personaId?: number) {
*/
export function useAdminLLMProviders() {
const { data, error, mutate } = useSWR<LLMProviderResponse<LLMProviderView>>(
SWR_KEYS.adminLlmProviders,
LLM_PROVIDERS_ADMIN_URL,
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
}
);
@@ -143,11 +141,12 @@ export function useAdminLLMProviders() {
*/
export function useWellKnownLLMProvider(providerEndpoint: string | null) {
const { data, error, isLoading } = useSWR<WellKnownLLMProviderDescriptor>(
providerEndpoint ? SWR_KEYS.wellKnownLlmProvider(providerEndpoint) : null,
providerEndpoint
? `/api/admin/llm/built-in/options/${providerEndpoint}`
: null,
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
}
);
@@ -166,11 +165,10 @@ export function useWellKnownLLMProviders() {
isLoading,
mutate,
} = useSWR<WellKnownLLMProviderDescriptor[]>(
SWR_KEYS.wellKnownLlmProviders,
"/api/admin/llm/built-in/options",
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
}
);

View File

@@ -3,7 +3,6 @@ import useSWR from "swr";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { LicenseStatus } from "@/lib/billing/interfaces";
import { SWR_KEYS } from "@/lib/swr-keys";
/**
* Hook to fetch license status for self-hosted deployments.
@@ -11,7 +10,7 @@ import { SWR_KEYS } from "@/lib/swr-keys";
* Skips the fetch on cloud deployments (uses tenant auth instead).
*/
export function useLicense() {
const url = NEXT_PUBLIC_CLOUD_ENABLED ? null : SWR_KEYS.license;
const url = NEXT_PUBLIC_CLOUD_ENABLED ? null : "/api/license";
const { data, error, mutate, isLoading } = useSWR<LicenseStatus>(
url,
@@ -19,7 +18,6 @@ export function useLicense() {
{
revalidateOnFocus: false,
revalidateOnReconnect: false,
revalidateIfStale: false,
dedupingInterval: 30000,
shouldRetryOnError: false,
keepPreviousData: true,

View File

@@ -3,17 +3,11 @@
import useSWR from "swr";
import { InputPrompt } from "@/app/app/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
export default function usePromptShortcuts() {
const { data, error, isLoading, mutate } = useSWR<InputPrompt[]>(
SWR_KEYS.promptShortcuts,
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
}
"/api/input_prompt",
errorHandlingFetcher
);
const promptShortcuts = data ?? [];

View File

@@ -1,6 +1,5 @@
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
import {
Settings,
EnterpriseSettings,
@@ -33,12 +32,11 @@ export function useSettings(): {
error: Error | undefined;
} {
const { data, error, isLoading } = useSWR<Settings>(
SWR_KEYS.settings,
"/api/settings",
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateOnReconnect: false,
revalidateIfStale: false,
dedupingInterval: 30_000,
errorRetryInterval: SETTINGS_ERROR_RETRY_INTERVAL,
}
@@ -63,12 +61,11 @@ export function useEnterpriseSettings(eeEnabledRuntime: boolean): {
const shouldFetch = EE_ENABLED || eeEnabledRuntime;
const { data, error, isLoading } = useSWR<EnterpriseSettings>(
shouldFetch ? SWR_KEYS.enterpriseSettings : null,
shouldFetch ? "/api/enterprise-settings" : null,
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateOnReconnect: false,
revalidateIfStale: false,
dedupingInterval: 30_000,
errorRetryInterval: SETTINGS_ERROR_RETRY_INTERVAL,
// Referential equality instead of SWR's default deep comparison.
@@ -92,12 +89,11 @@ export function useCustomAnalyticsScript(
const shouldFetch = EE_ENABLED || eeEnabledRuntime;
const { data } = useSWR<string>(
shouldFetch ? SWR_KEYS.customAnalyticsScript : null,
shouldFetch ? "/api/enterprise-settings/custom-analytics-script" : null,
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateOnReconnect: false,
revalidateIfStale: false,
dedupingInterval: 60_000,
}
);

View File

@@ -1,7 +1,6 @@
import useSWR from "swr";
import { Tag } from "@/lib/types";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
interface TagsResponse {
tags: Tag[];
@@ -19,11 +18,10 @@ interface TagsResponse {
*/
export default function useTags() {
const { data, error, mutate } = useSWR<TagsResponse>(
SWR_KEYS.tags,
"/api/query/valid-tags",
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
}
);

View File

@@ -1,6 +1,5 @@
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
export interface VoiceProviderView {
id: number;
@@ -15,13 +14,14 @@ export interface VoiceProviderView {
target_uri: string | null;
}
const VOICE_PROVIDERS_URL = "/api/admin/voice/providers";
export function useVoiceProviders() {
const { data, error, isLoading, mutate } = useSWR<VoiceProviderView[]>(
SWR_KEYS.voiceProviders,
VOICE_PROVIDERS_URL,
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
}
);

View File

@@ -6,8 +6,6 @@ import { INTERNAL_URL, IS_DEV } from "@/lib/constants";
const TARGET_SAMPLE_RATE = 24000;
const CHUNK_INTERVAL_MS = 250;
const DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS = 1500;
// When VAD-based auto-stop is disabled, force-stop after this much silence as a fallback
const SILENCE_FALLBACK_TIMEOUT_MS = 10000;
interface TranscriptMessage {
type: "transcript" | "error";
@@ -60,8 +58,6 @@ class VoiceRecorderSession {
private finalTranscriptDelivered = false;
private lastDeliveredFinalText: string | null = null;
private lastDeliveredFinalAtMs = 0;
// Fallback timer: force-stop after extended silence when VAD auto-stop is disabled
private silenceFallbackTimer: NodeJS.Timeout | null = null;
// Callbacks to update React state
private onTranscriptChange: (text: string) => void;
@@ -178,8 +174,6 @@ class VoiceRecorderSession {
async stop(): Promise<string | null> {
if (!this.isActive) return this.transcript || null;
this.resetSilenceFallbackTimer();
// Stop audio capture
if (this.sendInterval) {
clearInterval(this.sendInterval);
@@ -225,7 +219,6 @@ class VoiceRecorderSession {
}
cleanup(): void {
this.resetSilenceFallbackTimer();
if (this.sendInterval) clearInterval(this.sendInterval);
if (this.scriptNode) this.scriptNode.disconnect();
if (this.sourceNode) this.sourceNode.disconnect();
@@ -281,23 +274,6 @@ class VoiceRecorderSession {
});
}
/**
 * Cancels the pending silence-fallback timeout, if one is armed, and clears
 * the stored handle. Does nothing when no handle is stored.
 */
private resetSilenceFallbackTimer(): void {
  const pendingTimer = this.silenceFallbackTimer;
  if (!pendingTimer) {
    return;
  }
  clearTimeout(pendingTimer);
  this.silenceFallbackTimer = null;
}
/**
 * (Re)arms the silence-fallback timeout. Any previously armed timeout is
 * cancelled first, so at most one is pending at a time.
 */
private startSilenceFallbackTimer(): void {
  this.resetSilenceFallbackTimer();

  // Fires after SILENCE_FALLBACK_TIMEOUT_MS without a re-arm: force-stop as a
  // safety fallback, but only while the session is still active and a VAD
  // stop callback is registered.
  const forceStopAfterSilence = (): void => {
    if (!this.isActive || !this.onVADStop) {
      return;
    }
    this.onVADStop();
  };

  this.silenceFallbackTimer = setTimeout(
    forceStopAfterSilence,
    SILENCE_FALLBACK_TIMEOUT_MS
  );
}
private handleMessage = (event: MessageEvent): void => {
try {
const data: TranscriptMessage = JSON.parse(event.data);
@@ -305,53 +281,47 @@ class VoiceRecorderSession {
if (data.type === "transcript") {
if (data.text) {
this.transcript = data.text;
// Only push live updates to React while actively recording.
// After stop(), the final transcript is returned via stopResolver
// instead — this prevents stale text from reappearing in the
// input box when the user clears it and starts a new recording.
if (this.isActive) {
this.onTranscriptChange(data.text);
}
this.onTranscriptChange(data.text);
}
if (data.is_final && data.text) {
// Resolve stop promise if waiting — must run even after stop()
// so the caller receives the final transcript.
if (this.stopResolver) {
this.stopResolver(data.text);
this.stopResolver = null;
// VAD detected silence - trigger callback (only once per utterance)
const now = Date.now();
const isLikelyDuplicateFinal =
this.autoStopOnSilence &&
this.lastDeliveredFinalText === data.text &&
now - this.lastDeliveredFinalAtMs <
DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS;
if (
this.onFinalTranscript &&
!this.finalTranscriptDelivered &&
!isLikelyDuplicateFinal
) {
this.finalTranscriptDelivered = true;
this.lastDeliveredFinalText = data.text;
this.lastDeliveredFinalAtMs = now;
this.onFinalTranscript(data.text);
}
// Skip VAD logic if session is no longer active
if (!this.isActive) return;
// Auto-stop recording if enabled
if (this.autoStopOnSilence) {
// VAD detected silence — auto-stop and trigger callback
const now = Date.now();
const isLikelyDuplicateFinal =
this.lastDeliveredFinalText === data.text &&
now - this.lastDeliveredFinalAtMs <
DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS;
if (
this.onFinalTranscript &&
!this.finalTranscriptDelivered &&
!isLikelyDuplicateFinal
) {
this.finalTranscriptDelivered = true;
this.lastDeliveredFinalText = data.text;
this.lastDeliveredFinalAtMs = now;
this.onFinalTranscript(data.text);
}
// Trigger stop callback to update React state
if (this.onVADStop) {
this.onVADStop();
}
} else {
// Auto-stop disabled (push-to-talk): ignore VAD, keep recording.
// Start/reset a 10s fallback timer — if no new speech arrives,
// force-stop to avoid recording silence indefinitely.
this.startSilenceFallbackTimer();
// If not auto-stopping, reset for next utterance
this.transcript = "";
this.finalTranscriptDelivered = false;
this.onTranscriptChange("");
this.resetBackendTranscript();
}
// Resolve stop promise if waiting
if (this.stopResolver) {
this.stopResolver(data.text);
this.stopResolver = null;
}
}
} else if (data.type === "error") {

View File

@@ -1,6 +1,5 @@
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
interface VoiceStatus {
stt_enabled: boolean;
@@ -9,11 +8,10 @@ interface VoiceStatus {
export function useVoiceStatus() {
const { data, error, isLoading } = useSWR<VoiceStatus>(
SWR_KEYS.voiceStatus,
"/api/voice/status",
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
}
);

View File

@@ -7,7 +7,6 @@ import * as Sentry from "@sentry/nextjs";
if (process.env.NEXT_PUBLIC_SENTRY_DSN) {
Sentry.init({
dsn: process.env.NEXT_PUBLIC_SENTRY_DSN,
release: process.env.SENTRY_RELEASE,
// Setting this option to true will print useful information to the console while you're setting up Sentry.
debug: false,

View File

@@ -160,31 +160,6 @@ export const formatDateShort = (dateStr: string | null | undefined): string => {
});
};
/**
 * Render an ISO timestamp as "YYYY/MM/DD HH:MM:SS" using the local time zone
 * and a 24-hour clock. Meant for log views that need second-level precision.
 *
 * @param iso - Timestamp string accepted by the `Date` constructor.
 * @returns The formatted local date-time; month, day, hour, minute and second
 *   are zero-padded to two digits.
 */
export function formatDateTimeLog(iso: string): string {
  const date = new Date(iso);
  const twoDigits = (value: number): string =>
    value.toString().padStart(2, "0");

  const datePart = [
    date.getFullYear(),
    twoDigits(date.getMonth() + 1), // getMonth() is zero-based
    twoDigits(date.getDate()),
  ].join("/");

  const timePart = [date.getHours(), date.getMinutes(), date.getSeconds()]
    .map((component) => twoDigits(component))
    .join(":");

  return `${datePart} ${timePart}`;
}
/**
 * Format an ISO timestamp as "HH:MM:SS" (24-hour, local time).
 * Intended for compact time-only displays.
 *
 * @param iso - Timestamp string accepted by the `Date` constructor.
 * @returns The local time of day with two-digit hour, minute and second,
 *   formatted with the runtime's default locale. Hours are in the range 00–23.
 */
export function formatTimeOnly(iso: string): string {
  return new Date(iso).toLocaleTimeString(undefined, {
    hour: "2-digit",
    minute: "2-digit",
    second: "2-digit",
    // Pin the 0–23 cycle explicitly. `hour12: false` lets some engines pick
    // the "h24" cycle for 12-hour-default locales (e.g. en-US), which renders
    // midnight as "24:MM:SS" instead of "00:MM:SS".
    hourCycle: "h23",
  });
}
export function formatMmDdYyyy(d: string): string {
const date = new Date(d);
return `${date.getMonth() + 1}/${date.getDate()}/${date.getFullYear()}`;

View File

@@ -96,54 +96,6 @@ describe("LLM resolver helpers", () => {
});
});
test("prefers provider by name when multiple share the same type", () => {
  // Both providers are of type "anthropic" and expose the same model, so the
  // provider name is the only thing that distinguishes them.
  const buildSonnetConfiguration = () => ({
    name: "claude-sonnet-4-5",
    is_visible: true,
    max_input_tokens: null,
    supports_image_input: false,
    supports_reasoning: false,
  });

  const defaultAnthropic = makeProvider({
    id: 1,
    name: "Anthropic",
    provider: "anthropic",
    model_configurations: [buildSonnetConfiguration()],
  });
  const personalAnthropic = makeProvider({
    id: 2,
    name: "PersonalAnthropicToken",
    provider: "anthropic",
    model_configurations: [buildSonnetConfiguration()],
  });
  const providers: LLMProviderDescriptor[] = [
    defaultAnthropic,
    personalAnthropic,
  ];

  // The selection names the second provider explicitly.
  const selection = structureValue(
    "PersonalAnthropicToken",
    "anthropic",
    "claude-sonnet-4-5"
  );
  const descriptor = getValidLlmDescriptorForProviders(selection, providers);

  // The resolver must keep the explicitly named provider rather than falling
  // back to the first provider of the same type.
  expect(descriptor).toEqual({
    name: "PersonalAnthropicToken",
    provider: "anthropic",
    modelName: "claude-sonnet-4-5",
  });
});
test("uses first provider with models when no explicit default exists", () => {
const providers: LLMProviderDescriptor[] = [
makeProvider({

View File

@@ -603,16 +603,13 @@ export function getValidLlmDescriptorForProviders(
// This ensures we don't incorrectly match a model to the wrong provider
// when the same model name exists across multiple providers (e.g., gpt-5 in Azure and OpenAI)
if (model.provider && model.provider.length > 0) {
const hasModel = (p: LLMProviderDescriptor) =>
p.model_configurations.some((mc) => mc.name === model.modelName);
const typeMatches = llmProviders.filter(
(p) => p.provider === model.provider && hasModel(p)
const matchingProvider = llmProviders.find(
(p) =>
p.provider === model.provider &&
p.model_configurations
.map((modelConfiguration) => modelConfiguration.name)
.includes(model.modelName)
);
// When multiple providers share the same type (e.g., two "anthropic"
// providers with different API keys), prefer the one whose name matches
// the user's explicit selection to avoid silently switching providers.
const matchingProvider =
typeMatches.find((p) => p.name === model.name) ?? typeMatches[0];
if (matchingProvider) {
return {
...model,

View File

@@ -1,15 +1,13 @@
import useSWR from "swr";
import { DocumentSetSummary } from "@/lib/types";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
export function useDocumentSets() {
const { data, error, mutate } = useSWR<DocumentSetSummary[]>(
SWR_KEYS.documentSets,
"/api/manage/document-set",
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
}
);

View File

@@ -1,15 +1,13 @@
import useSWR from "swr";
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
export function useLLMProviderOptions() {
const { data, error, mutate } = useSWR<
WellKnownLLMProviderDescriptor[] | undefined
>(SWR_KEYS.wellKnownLlmProviders, errorHandlingFetcher, {
>("/api/admin/llm/built-in/options", errorHandlingFetcher, {
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 60000,
dedupingInterval: 60000, // Dedupe requests within 1 minute
});
return {

View File

@@ -1,15 +1,13 @@
import useSWR from "swr";
import { Project } from "@/app/app/projects/projectsService";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
export function useProjects() {
const { data, error, mutate } = useSWR<Project[]>(
SWR_KEYS.userProjects,
"/api/user/projects",
errorHandlingFetcher,
{
revalidateOnFocus: false,
revalidateIfStale: false,
dedupingInterval: 30000,
}
);

Some files were not shown because too many files have changed in this diff Show More