mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-31 20:42:41 +00:00
Compare commits
6 Commits
rag-script
...
multi-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c222348ff3 | ||
|
|
4ae5e96fab | ||
|
|
3365a369e2 | ||
|
|
470bda3fb5 | ||
|
|
13f511e209 | ||
|
|
c5e8ba1eab |
6
.github/workflows/deployment.yml
vendored
6
.github/workflows/deployment.yml
vendored
@@ -704,9 +704,6 @@ jobs:
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
SENTRY_RELEASE=${{ github.sha }}
|
||||
secrets: |
|
||||
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
@@ -789,9 +786,6 @@ jobs:
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
SENTRY_RELEASE=${{ github.sha }}
|
||||
secrets: |
|
||||
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
|
||||
2
.github/workflows/pr-helm-chart-testing.yml
vendored
2
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@2e2940618cb426dce2999631d543b53cdcfc8527
|
||||
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from redis.lock import Lock as RedisLock
|
||||
from ee.onyx.server.tenants.provisioning import setup_tenant
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
|
||||
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
@@ -30,10 +29,9 @@ from shared_configs.configs import TENANT_ID_PREFIX
|
||||
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
|
||||
_MAX_TENANTS_PER_RUN = 5
|
||||
|
||||
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
|
||||
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
|
||||
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -93,7 +91,8 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
|
||||
if batch_size < tenants_to_provision:
|
||||
task_logger.info(
|
||||
f"Capping batch to {batch_size} (need {tenants_to_provision}, will catch up next cycle)"
|
||||
f"Capping batch to {batch_size} "
|
||||
f"(need {tenants_to_provision}, will catch up next cycle)"
|
||||
)
|
||||
|
||||
provisioned = 0
|
||||
@@ -104,14 +103,12 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
provisioned += 1
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Failed to provision tenant {i + 1}/{batch_size}, continuing with remaining tenants"
|
||||
f"Failed to provision tenant {i + 1}/{batch_size}, "
|
||||
"continuing with remaining tenants"
|
||||
)
|
||||
|
||||
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
|
||||
|
||||
# Migrate any pool tenants that were provisioned before a new migration was deployed
|
||||
_migrate_stale_pool_tenants()
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
|
||||
@@ -124,46 +121,6 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
)
|
||||
|
||||
|
||||
def _migrate_stale_pool_tenants() -> None:
|
||||
"""
|
||||
Run alembic upgrade head on all pool tenants. Since alembic upgrade head is
|
||||
idempotent, tenants already at head are a fast no-op. This ensures pool
|
||||
tenants are always current so that signup doesn't hit schema mismatches
|
||||
(e.g. missing columns added after the tenant was pre-provisioned).
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
pool_tenants = db_session.query(AvailableTenant).all()
|
||||
tenant_ids = [t.tenant_id for t in pool_tenants]
|
||||
|
||||
if not tenant_ids:
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Checking {len(tenant_ids)} pool tenant(s) for pending migrations"
|
||||
)
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
try:
|
||||
run_alembic_migrations(tenant_id)
|
||||
new_version = get_current_alembic_version(tenant_id)
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.filter_by(tenant_id=tenant_id)
|
||||
.first()
|
||||
)
|
||||
if tenant and tenant.alembic_version != new_version:
|
||||
task_logger.info(
|
||||
f"Migrated pool tenant {tenant_id}: {tenant.alembic_version} -> {new_version}"
|
||||
)
|
||||
tenant.alembic_version = new_version
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Failed to migrate pool tenant {tenant_id}, skipping"
|
||||
)
|
||||
|
||||
|
||||
def pre_provision_tenant() -> bool:
|
||||
"""
|
||||
Pre-provision a new tenant and store it in the NewAvailableTenant table.
|
||||
|
||||
@@ -99,26 +99,6 @@ async def get_or_provision_tenant(
|
||||
tenant_id = await get_available_tenant()
|
||||
|
||||
if tenant_id:
|
||||
# Run migrations to ensure the pre-provisioned tenant schema is current.
|
||||
# Pool tenants may have been created before a new migration was deployed.
|
||||
# Capture as a non-optional local so mypy can type the lambda correctly.
|
||||
_tenant_id: str = tenant_id
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None, lambda: run_alembic_migrations(_tenant_id)
|
||||
)
|
||||
except Exception:
|
||||
# The tenant was already dequeued from the pool — roll it back so
|
||||
# it doesn't end up orphaned (schema exists, but not assigned to anyone).
|
||||
logger.exception(
|
||||
f"Migration failed for pre-provisioned tenant {_tenant_id}; rolling back"
|
||||
)
|
||||
try:
|
||||
await rollback_tenant_provisioning(_tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to rollback orphaned tenant {_tenant_id}")
|
||||
raise
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
|
||||
@@ -100,7 +100,6 @@ def get_model_app() -> FastAPI:
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -20,7 +20,6 @@ from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
@@ -66,7 +65,6 @@ if SENTRY_DSN:
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
@@ -517,8 +515,7 @@ def reset_tenant_id(
|
||||
|
||||
|
||||
def wait_for_vespa_or_shutdown(
|
||||
sender: Any, # noqa: ARG001
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
sender: Any, **kwargs: Any # noqa: ARG001
|
||||
) -> None: # noqa: ARG001
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
@@ -9,7 +9,6 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
|
||||
@@ -138,7 +137,6 @@ def _docfetching_task(
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -319,11 +319,6 @@ def monitor_indexing_attempt_progress(
|
||||
)
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
total_batches: int | str = (
|
||||
coordination_status.total_batches
|
||||
if coordination_status.total_batches is not None
|
||||
else "?"
|
||||
)
|
||||
if coordination_status.found:
|
||||
task_logger.info(
|
||||
f"Indexing attempt progress: "
|
||||
@@ -331,7 +326,7 @@ def monitor_indexing_attempt_progress(
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id} "
|
||||
f"completed_batches={coordination_status.completed_batches} "
|
||||
f"total_batches={total_batches} "
|
||||
f"total_batches={coordination_status.total_batches or '?'} "
|
||||
f"total_docs={coordination_status.total_docs} "
|
||||
f"total_failures={coordination_status.total_failures}"
|
||||
f"elapsed={(current_db_time - attempt.time_created).seconds}"
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
@@ -159,114 +148,3 @@ class ChatStateContainer:
|
||||
"""Thread-safe getter for emitted citations (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
with event streaming capabilities.
|
||||
|
||||
The wrapped function should accept emitter as first arg and use it to emit
|
||||
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
pass
|
||||
"""
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
# Run the function in a background thread
|
||||
thread = run_in_background(run_with_exception_capture)
|
||||
|
||||
pkt: Packet | None = None
|
||||
last_turn_index = 0 # Track the highest turn_index seen for stop packet
|
||||
last_cancel_check = time.monotonic()
|
||||
cancel_check_interval = 0.3 # Check for cancellation every 300ms
|
||||
try:
|
||||
while True:
|
||||
# Poll queue with 300ms timeout for natural stop signal checking
|
||||
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
|
||||
try:
|
||||
pkt = emitter.bus.get(timeout=0.3)
|
||||
except Empty:
|
||||
if not is_connected():
|
||||
# Stop signal detected
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = time.monotonic()
|
||||
continue
|
||||
|
||||
if pkt is not None:
|
||||
# Track the highest turn_index for the stop packet
|
||||
if pkt.placement and pkt.placement.turn_index > last_turn_index:
|
||||
last_turn_index = pkt.placement.turn_index
|
||||
|
||||
if isinstance(pkt.obj, OverallStop):
|
||||
yield pkt
|
||||
break
|
||||
elif isinstance(pkt.obj, PacketException):
|
||||
raise pkt.obj.exception
|
||||
else:
|
||||
yield pkt
|
||||
|
||||
# Check for cancellation periodically even when packets are flowing
|
||||
# This ensures stop signal is checked during active streaming
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_cancel_check >= cancel_check_interval:
|
||||
if not is_connected():
|
||||
# Stop signal detected during streaming
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = current_time
|
||||
finally:
|
||||
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,19 +1,45 @@
|
||||
import logging
|
||||
import queue
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Emitter:
|
||||
"""Use this inside tools to emit arbitrary UI progress."""
|
||||
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
|
||||
|
||||
def __init__(self, bus: Queue):
|
||||
self.bus = bus
|
||||
Tags every packet with ``model_index`` and places it on ``merged_queue``
|
||||
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
|
||||
|
||||
Args:
|
||||
merged_queue: Shared queue owned by ``_run_models``.
|
||||
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
merged_queue: Queue[Any],
|
||||
model_idx: int = 0,
|
||||
) -> None:
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
self.bus.put(packet) # Thread-safe
|
||||
|
||||
|
||||
def get_default_emitter() -> Emitter:
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
return emitter
|
||||
base = packet.placement or Placement(turn_index=0)
|
||||
tagged = Packet(
|
||||
placement=base.model_copy(update={"model_index": self._model_idx}),
|
||||
obj=packet.obj,
|
||||
)
|
||||
try:
|
||||
self._merged_queue.put((self._model_idx, tagged), timeout=3.0)
|
||||
except queue.Full:
|
||||
# Drain loop is gone (e.g. GeneratorExit on disconnect); discard packet.
|
||||
logger.warning(
|
||||
"Emitter model_idx=%d: queue full after 3s timeout, dropping packet %s",
|
||||
self._model_idx,
|
||||
type(packet.obj).__name__,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -212,7 +212,6 @@ class DocumentSource(str, Enum):
|
||||
PRODUCTBOARD = "productboard"
|
||||
FILE = "file"
|
||||
CODA = "coda"
|
||||
CANVAS = "canvas"
|
||||
NOTION = "notion"
|
||||
ZULIP = "zulip"
|
||||
LINEAR = "linear"
|
||||
@@ -673,7 +672,6 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
|
||||
DocumentSource.SLAB: "slab data",
|
||||
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
|
||||
DocumentSource.FILE: "files",
|
||||
DocumentSource.CANVAS: "canvas lms - courses, pages, assignments, and announcements",
|
||||
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
|
||||
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
|
||||
project management, and collaboration tools into a single, customizable platform",
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
"""
|
||||
Permissioning / AccessControl logic for Canvas courses.
|
||||
|
||||
CE stub — returns None (no permissions). The EE implementation is loaded
|
||||
at runtime via ``fetch_versioned_implementation``.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
def get_course_permissions(
|
||||
canvas_client: CanvasApiClient,
|
||||
course_id: int,
|
||||
) -> ExternalAccess | None:
|
||||
if not global_version.is_ee_version():
|
||||
return None
|
||||
|
||||
ee_get_course_permissions = cast(
|
||||
Callable[[CanvasApiClient, int], ExternalAccess | None],
|
||||
fetch_versioned_implementation(
|
||||
"onyx.external_permissions.canvas.access",
|
||||
"get_course_permissions",
|
||||
),
|
||||
)
|
||||
|
||||
return ee_get_course_permissions(canvas_client, course_id)
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -191,22 +190,3 @@ class CanvasApiClient:
|
||||
if clean_endpoint:
|
||||
final_url += "/" + clean_endpoint
|
||||
return final_url
|
||||
|
||||
def paginate(
|
||||
self,
|
||||
endpoint: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Iterator[list[Any]]:
|
||||
"""Yield each page of results, following Link-header pagination.
|
||||
|
||||
Makes the first request with endpoint + params, then follows
|
||||
next_url from Link headers for subsequent pages.
|
||||
"""
|
||||
response, next_url = self.get(endpoint, params=params)
|
||||
while True:
|
||||
if not response:
|
||||
break
|
||||
yield response
|
||||
if not next_url:
|
||||
break
|
||||
response, next_url = self.get(full_url=next_url)
|
||||
|
||||
@@ -1,82 +1,17 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
from typing import NoReturn
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.canvas.access import get_course_permissions
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
|
||||
"""Map Canvas API errors to connector framework exceptions."""
|
||||
if e.status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Canvas API token is invalid or expired (HTTP 401)."
|
||||
)
|
||||
elif e.status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Canvas API token does not have sufficient permissions (HTTP 403)."
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
|
||||
)
|
||||
elif e.status_code >= 500:
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
|
||||
)
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"Canvas API error (status={e.status_code}): {e}"
|
||||
)
|
||||
|
||||
|
||||
class CanvasCourse(BaseModel):
|
||||
id: int
|
||||
name: str | None = None
|
||||
course_code: str | None = None
|
||||
created_at: str | None = None
|
||||
workflow_state: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any]) -> "CanvasCourse":
|
||||
return cls(
|
||||
id=payload["id"],
|
||||
name=payload.get("name"),
|
||||
course_code=payload.get("course_code"),
|
||||
created_at=payload.get("created_at"),
|
||||
workflow_state=payload.get("workflow_state"),
|
||||
)
|
||||
name: str
|
||||
course_code: str
|
||||
created_at: str
|
||||
workflow_state: str
|
||||
|
||||
|
||||
class CanvasPage(BaseModel):
|
||||
@@ -84,22 +19,10 @@ class CanvasPage(BaseModel):
|
||||
url: str
|
||||
title: str
|
||||
body: str | None = None
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
course_id: int
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasPage":
|
||||
return cls(
|
||||
page_id=payload["page_id"],
|
||||
url=payload["url"],
|
||||
title=payload["title"],
|
||||
body=payload.get("body"),
|
||||
created_at=payload.get("created_at"),
|
||||
updated_at=payload.get("updated_at"),
|
||||
course_id=course_id,
|
||||
)
|
||||
|
||||
|
||||
class CanvasAssignment(BaseModel):
|
||||
id: int
|
||||
@@ -107,23 +30,10 @@ class CanvasAssignment(BaseModel):
|
||||
description: str | None = None
|
||||
html_url: str
|
||||
course_id: int
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
due_at: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAssignment":
|
||||
return cls(
|
||||
id=payload["id"],
|
||||
name=payload["name"],
|
||||
description=payload.get("description"),
|
||||
html_url=payload["html_url"],
|
||||
course_id=course_id,
|
||||
created_at=payload.get("created_at"),
|
||||
updated_at=payload.get("updated_at"),
|
||||
due_at=payload.get("due_at"),
|
||||
)
|
||||
|
||||
|
||||
class CanvasAnnouncement(BaseModel):
|
||||
id: int
|
||||
@@ -133,17 +43,6 @@ class CanvasAnnouncement(BaseModel):
|
||||
posted_at: str | None = None
|
||||
course_id: int
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAnnouncement":
|
||||
return cls(
|
||||
id=payload["id"],
|
||||
title=payload["title"],
|
||||
message=payload.get("message"),
|
||||
html_url=payload["html_url"],
|
||||
posted_at=payload.get("posted_at"),
|
||||
course_id=course_id,
|
||||
)
|
||||
|
||||
|
||||
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
|
||||
|
||||
@@ -173,286 +72,3 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
self.current_course_index += 1
|
||||
self.stage = "pages"
|
||||
self.next_url = None
|
||||
|
||||
|
||||
class CanvasConnector(
|
||||
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
canvas_base_url: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.canvas_base_url = canvas_base_url.rstrip("/").removesuffix("/api/v1")
|
||||
self.batch_size = batch_size
|
||||
self._canvas_client: CanvasApiClient | None = None
|
||||
self._course_permissions_cache: dict[int, ExternalAccess | None] = {}
|
||||
|
||||
@property
|
||||
def canvas_client(self) -> CanvasApiClient:
|
||||
if self._canvas_client is None:
|
||||
raise ConnectorMissingCredentialError("Canvas")
|
||||
return self._canvas_client
|
||||
|
||||
def _get_course_permissions(self, course_id: int) -> ExternalAccess | None:
|
||||
"""Get course permissions with caching."""
|
||||
if course_id not in self._course_permissions_cache:
|
||||
self._course_permissions_cache[course_id] = get_course_permissions(
|
||||
canvas_client=self.canvas_client,
|
||||
course_id=course_id,
|
||||
)
|
||||
return self._course_permissions_cache[course_id]
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_courses(self) -> list[CanvasCourse]:
|
||||
"""Fetch all courses accessible to the authenticated user."""
|
||||
logger.debug("Fetching Canvas courses")
|
||||
|
||||
courses: list[CanvasCourse] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
"courses", params={"per_page": "100", "state[]": "available"}
|
||||
):
|
||||
courses.extend(CanvasCourse.from_api(c) for c in page)
|
||||
return courses
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_pages(self, course_id: int) -> list[CanvasPage]:
|
||||
"""Fetch all pages for a given course."""
|
||||
logger.debug(f"Fetching pages for course {course_id}")
|
||||
|
||||
pages: list[CanvasPage] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
f"courses/{course_id}/pages",
|
||||
params={"per_page": "100", "include[]": "body", "published": "true"},
|
||||
):
|
||||
pages.extend(CanvasPage.from_api(p, course_id=course_id) for p in page)
|
||||
return pages
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_assignments(self, course_id: int) -> list[CanvasAssignment]:
|
||||
"""Fetch all assignments for a given course."""
|
||||
logger.debug(f"Fetching assignments for course {course_id}")
|
||||
|
||||
assignments: list[CanvasAssignment] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
f"courses/{course_id}/assignments",
|
||||
params={"per_page": "100", "published": "true"},
|
||||
):
|
||||
assignments.extend(
|
||||
CanvasAssignment.from_api(a, course_id=course_id) for a in page
|
||||
)
|
||||
return assignments
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_announcements(self, course_id: int) -> list[CanvasAnnouncement]:
|
||||
"""Fetch all announcements for a given course."""
|
||||
logger.debug(f"Fetching announcements for course {course_id}")
|
||||
|
||||
announcements: list[CanvasAnnouncement] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
"announcements",
|
||||
params={
|
||||
"per_page": "100",
|
||||
"context_codes[]": f"course_{course_id}",
|
||||
"active_only": "true",
|
||||
},
|
||||
):
|
||||
announcements.extend(
|
||||
CanvasAnnouncement.from_api(a, course_id=course_id) for a in page
|
||||
)
|
||||
return announcements
|
||||
|
||||
def _build_document(
|
||||
self,
|
||||
doc_id: str,
|
||||
link: str,
|
||||
text: str,
|
||||
semantic_identifier: str,
|
||||
doc_updated_at: datetime | None,
|
||||
course_id: int,
|
||||
doc_type: str,
|
||||
) -> Document:
|
||||
"""Build a Document with standard Canvas fields."""
|
||||
return Document(
|
||||
id=doc_id,
|
||||
sections=cast(
|
||||
list[TextSection | ImageSection],
|
||||
[TextSection(link=link, text=text)],
|
||||
),
|
||||
source=DocumentSource.CANVAS,
|
||||
semantic_identifier=semantic_identifier,
|
||||
doc_updated_at=doc_updated_at,
|
||||
metadata={"course_id": str(course_id), "type": doc_type},
|
||||
)
|
||||
|
||||
def _convert_page_to_document(self, page: CanvasPage) -> Document:
|
||||
"""Convert a Canvas page to a Document."""
|
||||
link = f"{self.canvas_base_url}/courses/{page.course_id}/pages/{page.url}"
|
||||
|
||||
text_parts = [page.title]
|
||||
body_text = parse_html_page_basic(page.body) if page.body else ""
|
||||
if body_text:
|
||||
text_parts.append(body_text)
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
if page.updated_at
|
||||
else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
|
||||
link=link,
|
||||
text="\n\n".join(text_parts),
|
||||
semantic_identifier=page.title or f"Page {page.page_id}",
|
||||
doc_updated_at=doc_updated_at,
|
||||
course_id=page.course_id,
|
||||
doc_type="page",
|
||||
)
|
||||
return document
|
||||
|
||||
def _convert_assignment_to_document(self, assignment: CanvasAssignment) -> Document:
|
||||
"""Convert a Canvas assignment to a Document."""
|
||||
text_parts = [assignment.name]
|
||||
desc_text = (
|
||||
parse_html_page_basic(assignment.description)
|
||||
if assignment.description
|
||||
else ""
|
||||
)
|
||||
if desc_text:
|
||||
text_parts.append(desc_text)
|
||||
if assignment.due_at:
|
||||
due_dt = datetime.fromisoformat(
|
||||
assignment.due_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(
|
||||
assignment.updated_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
if assignment.updated_at
|
||||
else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
doc_id=f"canvas-assignment-{assignment.course_id}-{assignment.id}",
|
||||
link=assignment.html_url,
|
||||
text="\n\n".join(text_parts),
|
||||
semantic_identifier=assignment.name or f"Assignment {assignment.id}",
|
||||
doc_updated_at=doc_updated_at,
|
||||
course_id=assignment.course_id,
|
||||
doc_type="assignment",
|
||||
)
|
||||
return document
|
||||
|
||||
def _convert_announcement_to_document(
    self, announcement: CanvasAnnouncement
) -> Document:
    """Convert a Canvas announcement to a Document.

    The document text is the announcement title followed by the message
    (passed through ``parse_html_page_basic``), joined by a blank line.
    Parts that are empty are left out. ``posted_at`` is used as the
    document's update time.

    Args:
        announcement: The Canvas announcement record to convert.

    Returns:
        The document produced by ``self._build_document``.
    """
    # The title may be empty (the semantic identifier below has a fallback for
    # exactly that case). Skip it then, so the text neither starts with a
    # blank segment nor makes ``str.join`` fail on a non-string value.
    text_parts = [announcement.title] if announcement.title else []

    if announcement.message:
        msg_text = parse_html_page_basic(announcement.message)
        if msg_text:
            text_parts.append(msg_text)

    doc_updated_at = (
        # A trailing "Z" is rewritten as an explicit UTC offset before parsing.
        datetime.fromisoformat(
            announcement.posted_at.replace("Z", "+00:00")
        ).astimezone(timezone.utc)
        if announcement.posted_at
        else None
    )

    return self._build_document(
        doc_id=f"canvas-announcement-{announcement.course_id}-{announcement.id}",
        link=announcement.html_url,
        text="\n\n".join(text_parts),
        semantic_identifier=announcement.title or f"Announcement {announcement.id}",
        doc_updated_at=doc_updated_at,
        course_id=announcement.course_id,
        doc_type="announcement",
    )
|
||||
|
||||
@override
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
    """Load and validate Canvas credentials.

    Reads ``canvas_access_token`` from ``credentials``, builds a
    ``CanvasApiClient`` for ``self.canvas_base_url`` and issues a one-item
    ``courses`` request as a smoke test. The client is stored on
    ``self._canvas_client`` afterwards.

    Args:
        credentials: Credential mapping; must contain a non-empty
            ``canvas_access_token``.

    Returns:
        Always ``None``.

    Raises:
        ConnectorMissingCredentialError: If no access token is supplied.
        ConnectorValidationError: If building or probing the client raises
            ``ValueError``.
    """
    access_token = credentials.get("canvas_access_token")
    if not access_token:
        raise ConnectorMissingCredentialError("Canvas")

    try:
        client = CanvasApiClient(
            bearer_token=access_token,
            canvas_base_url=self.canvas_base_url,
        )
        client.get("courses", params={"per_page": "1"})
    except ValueError as e:
        # Chain explicitly so the original failure is kept as the cause.
        raise ConnectorValidationError(f"Invalid Canvas base URL: {e}") from e
    except OnyxError as e:
        # NOTE(review): presumably this helper always raises a connector
        # error; if it can return, the client below is stored unvalidated.
        _handle_canvas_api_error(e)

    self._canvas_client = client
    return None
|
||||
|
||||
@override
def validate_connector_settings(self) -> None:
    """Validate Canvas connector settings by testing API access.

    Issues a one-item ``courses`` request through ``self.canvas_client``.

    Raises:
        ConnectorMissingCredentialError: Re-raised unchanged when it is
            raised while accessing or using the client.
        UnexpectedValidationError: For any other unexpected exception.
    """
    try:
        self.canvas_client.get("courses", params={"per_page": "1"})
        logger.info("Canvas connector settings validated successfully")
    except OnyxError as e:
        # NOTE(review): presumably translates the API error into a connector
        # validation error and raises it; confirm against the helper.
        _handle_canvas_api_error(e)
    except ConnectorMissingCredentialError:
        raise
    except Exception as exc:
        # Chain explicitly so the original failure is kept as the cause.
        raise UnexpectedValidationError(
            f"Unexpected error during Canvas settings validation: {exc}"
        ) from exc
|
||||
|
||||
@override
def load_from_checkpoint(
    self,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
    checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
    """Checkpointed loading entry point; currently a stub.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(benwu408): implemented in PR3 (checkpoint)
    raise NotImplementedError()
|
||||
|
||||
@override
def load_from_checkpoint_with_perm_sync(
    self,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
    checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
    """Checkpointed loading with permission sync; currently a stub.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(benwu408): implemented in PR3 (checkpoint)
    raise NotImplementedError()
|
||||
|
||||
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
    """Build an initial checkpoint; currently a stub.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(benwu408): implemented in PR3 (checkpoint)
    raise NotImplementedError()
|
||||
|
||||
@override
def validate_checkpoint_json(
    self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
    """Parse a serialized checkpoint; currently a stub.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(benwu408): implemented in PR3 (checkpoint)
    raise NotImplementedError()
|
||||
|
||||
@override
def retrieve_all_slim_docs_perm_sync(
    self,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
    """Slim-document retrieval for permission sync; currently a stub.

    Raises:
        NotImplementedError: Always.
    """
    # TODO(benwu408): implemented in PR4 (perm sync)
    raise NotImplementedError()
|
||||
|
||||
@@ -72,10 +72,6 @@ CONNECTOR_CLASS_MAP = {
|
||||
module_path="onyx.connectors.coda.connector",
|
||||
class_name="CodaConnector",
|
||||
),
|
||||
DocumentSource.CANVAS: ConnectorMapping(
|
||||
module_path="onyx.connectors.canvas.connector",
|
||||
class_name="CanvasConnector",
|
||||
),
|
||||
DocumentSource.NOTION: ConnectorMapping(
|
||||
module_path="onyx.connectors.notion.connector",
|
||||
class_name="NotionConnector",
|
||||
|
||||
@@ -617,6 +617,92 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
    db_session: Session,
    chat_session_id: UUID,
    parent_message_id: int,
    model_display_names: list[str],
) -> list[ChatMessage]:
    """Reserve N assistant message placeholders for multi-model parallel streaming.

    All messages share the same parent (the user message). The parent's
    latest_child_message_id points to the LAST reserved message so that the
    default history-chain walker picks it up.

    Args:
        db_session: Active database session; flushed and committed here.
        chat_session_id: Chat session the placeholders belong to.
        parent_message_id: ID of the message every placeholder is a child of.
        model_display_names: One display name per placeholder to create. May
            be empty, in which case nothing is reserved and the parent is
            left untouched.

    Returns:
        The reserved messages, in the order of ``model_display_names``.
    """
    reserved: list[ChatMessage] = []
    for display_name in model_display_names:
        msg = ChatMessage(
            chat_session_id=chat_session_id,
            parent_message_id=parent_message_id,
            latest_child_message_id=None,
            message="Response was terminated prior to completion, try regenerating.",
            token_count=15,  # placeholder; updated on completion by llm_loop_completion_handle
            message_type=MessageType.ASSISTANT,
            model_display_name=display_name,
        )
        db_session.add(msg)
        reserved.append(msg)

    # Flush to assign IDs without committing yet
    db_session.flush()

    # Point parent's latest_child to the last reserved message
    parent = (
        db_session.query(ChatMessage)
        .filter(ChatMessage.id == parent_message_id)
        .first()
    )
    # Guard on ``reserved`` as well: with no display names there is no last
    # message, and indexing ``reserved[-1]`` would raise IndexError.
    if parent and reserved:
        parent.latest_child_message_id = reserved[-1].id

    db_session.commit()
    return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
    db_session: Session,
    user_message_id: int,
    preferred_assistant_message_id: int,
) -> None:
    """Record which assistant reply the user prefers for a multi-model turn.

    Besides storing the preference, ``latest_child_message_id`` of the user
    message is moved to the preferred reply, and the change is committed.

    Args:
        db_session: Active database session.
        user_message_id: Primary key of the ``USER``-type ``ChatMessage``
            whose preferred response is being set.
        preferred_assistant_message_id: Primary key of the ``ChatMessage`` to
            prefer. Must be a direct child of ``user_message_id``.

    Raises:
        ValueError: If either message is not found, if ``user_message_id``
            does not refer to a USER message, or if the preferred message is
            not a direct child of the user message.
    """
    prompt_msg = db_session.get(ChatMessage, user_message_id)
    if prompt_msg is None:
        raise ValueError(f"User message {user_message_id} not found")
    if prompt_msg.message_type != MessageType.USER:
        raise ValueError(f"Message {user_message_id} is not a user message")

    chosen_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
    if chosen_msg is None:
        not_found = f"Assistant message {preferred_assistant_message_id} not found"
        raise ValueError(not_found)
    if chosen_msg.parent_message_id != user_message_id:
        not_child = (
            f"Assistant message {preferred_assistant_message_id} is not a child "
            f"of user message {user_message_id}"
        )
        raise ValueError(not_child)

    # Store the preference and make the preferred reply the active branch.
    prompt_msg.preferred_response_id = preferred_assistant_message_id
    prompt_msg.latest_child_message_id = preferred_assistant_message_id
    db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -839,6 +925,8 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -932,7 +932,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
def search_for_document_ids(
|
||||
self,
|
||||
body: dict[str, Any],
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
|
||||
) -> list[str]:
|
||||
"""Searches the index and returns only document chunk IDs.
|
||||
|
||||
|
||||
@@ -60,7 +60,8 @@ class OpenSearchSearchType(str, Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
RANDOM = "random"
|
||||
DOC_ID_RETRIEVAL = "doc_id_retrieval"
|
||||
ID_RETRIEVAL = "id_retrieval"
|
||||
DOCUMENT_IDS = "document_ids"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
|
||||
@@ -928,7 +928,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.DOC_ID_RETRIEVAL,
|
||||
search_type=OpenSearchSearchType.ID_RETRIEVAL,
|
||||
)
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
|
||||
@@ -8,6 +8,24 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMOverride(BaseModel):
|
||||
"""Per-request LLM settings that override persona defaults.
|
||||
|
||||
All fields are optional — only the fields that differ from the persona's
|
||||
configured LLM need to be supplied. Used both over the wire (API requests)
|
||||
and for multi-model comparison, where one override is supplied per model.
|
||||
|
||||
Attributes:
|
||||
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
|
||||
When ``None``, the persona's default provider is used.
|
||||
model_version: Specific model version string (e.g. ``"gpt-4o"``).
|
||||
When ``None``, the persona's default model is used.
|
||||
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
|
||||
persona's default temperature is used.
|
||||
display_name: Human-readable label shown in the UI for this model,
|
||||
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
|
||||
when not set.
|
||||
"""
|
||||
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
@@ -439,7 +439,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.chat.chat_utils import extract_headers
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
@@ -46,6 +47,7 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -60,6 +62,8 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -570,6 +575,46 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -660,6 +705,30 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
def set_preferred_response_endpoint(
    request_body: SetPreferredResponseRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Set the preferred assistant response for a multi-model turn.

    Args:
        request_body: Carries ``user_message_id`` and ``preferred_response_id``.
        user: The requesting user, or ``None``.
        db_session: Database session for the request.

    Raises:
        OnyxError: With ``INVALID_INPUT`` when the ownership check or
            ``set_preferred_response`` raises ``ValueError``.
    """
    try:
        # Ownership check: get_chat_message raises ValueError if the message
        # doesn't belong to this user, preventing cross-user mutation.
        get_chat_message(
            chat_message_id=request_body.user_message_id,
            user_id=user.id if user else None,
            db_session=db_session,
        )
        set_preferred_response(
            db_session=db_session,
            user_message_id=request_body.user_message_id,
            preferred_assistant_message_id=request_body.preferred_response_id,
        )
    except ValueError as e:
        # Chain explicitly so the original failure is kept as the cause.
        raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e)) from e
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -2,11 +2,25 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class Placement(BaseModel):
    """Coordinates that identify where a streaming packet belongs in the UI.

    The frontend uses these fields to route each packet to the correct turn,
    tool tab, agent sub-turn, and (in multi-model mode) response column.

    Attributes:
        turn_index: Monotonically increasing index of the iterative reasoning block
            (e.g. tool call round) within this chat message. Lower values happened first.
        tab_index: Disambiguates parallel tool calls within the same turn so each
            tool's output can be displayed in its own tab.
        sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
            top-level packets; an integer for tool-within-tool output.
        model_index: Which model this packet belongs to. ``0`` for single-model
            responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
            for pre-LLM setup packets (e.g. message ID info) that are yielded
            before any Emitter runs.
    """

    # Which iterative block in the UI this packet is part of; blocks are
    # ordered and smaller indices happened first.
    turn_index: int
    # For parallel tool calls to preserve order of execution
    tab_index: int = 0
    # Used for tools/agents that call other tools; nested agents are not
    # supported currently but can be added later.
    sub_turn_index: int | None = None
    # For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
    model_index: int | None = None
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -708,7 +709,6 @@ def run_research_agent_calls(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from queue import Queue
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
@@ -744,8 +744,8 @@ if __name__ == "__main__":
|
||||
if user is None:
|
||||
raise ValueError("No users found in database. Please create a user first.")
|
||||
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
emitter_queue: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(merged_queue=emitter_queue)
|
||||
state_container = ChatStateContainer()
|
||||
|
||||
tool_dict = construct_tools(
|
||||
@@ -792,4 +792,4 @@ if __name__ == "__main__":
|
||||
print(result.intermediate_report)
|
||||
print("=" * 80)
|
||||
print(f"Citations: {result.citation_mapping}")
|
||||
print(f"Total packets emitted: {bus.qsize()}")
|
||||
print(f"Total packets emitted: {emitter_queue.qsize()}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import csv
|
||||
import json
|
||||
import queue
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
@@ -11,7 +12,6 @@ import requests
|
||||
from requests import JSONDecodeError
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
|
||||
# Use default emitter if none provided
|
||||
# Use a discard emitter if none provided (packets go nowhere)
|
||||
if emitter is None:
|
||||
emitter = get_default_emitter()
|
||||
emitter = Emitter(merged_queue=queue.Queue())
|
||||
|
||||
return [
|
||||
CustomTool(
|
||||
@@ -367,7 +367,7 @@ if __name__ == "__main__":
|
||||
tools = build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=openapi_schema,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
dynamic_schema_info=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -28,9 +27,6 @@ INTERNAL_SEARCH_TOOL_NAME = "internal_search"
|
||||
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
|
||||
MAX_REQUEST_ATTEMPTS = 5
|
||||
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
|
||||
QUESTION_TIMEOUT_SECONDS = 300
|
||||
QUESTION_RETRY_PAUSE_SECONDS = 30
|
||||
MAX_QUESTION_ATTEMPTS = 3
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -113,27 +109,6 @@ def normalize_api_base(api_base: str) -> str:
|
||||
return f"{normalized}/api"
|
||||
|
||||
|
||||
def load_completed_question_ids(output_file: Path) -> set[str]:
    """Collect the question IDs already recorded in a JSON-lines output file.

    Each non-blank line is parsed as JSON. Lines that are not valid JSON, are
    not JSON objects, or lack a non-empty string ``question_id`` are ignored.

    Args:
        output_file: Path of the JSON-lines file; it may not exist yet.

    Returns:
        The set of ``question_id`` values found, empty if the file is missing.
    """
    if not output_file.exists():
        return set()

    completed_ids: set[str] = set()
    with output_file.open("r", encoding="utf-8") as file:
        for line in file:
            stripped = line.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (list, number, string, ...)
            # has no ``.get``; treat it like any other unusable line.
            if not isinstance(record, dict):
                continue
            question_id = record.get("question_id")
            if isinstance(question_id, str) and question_id:
                completed_ids.add(question_id)

    return completed_ids
|
||||
|
||||
|
||||
def load_questions(questions_file: Path) -> list[QuestionRecord]:
|
||||
if not questions_file.exists():
|
||||
raise FileNotFoundError(f"Questions file not found: {questions_file}")
|
||||
@@ -373,7 +348,6 @@ async def generate_answers(
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
parallelism: int,
|
||||
skipped: int,
|
||||
) -> None:
|
||||
if parallelism < 1:
|
||||
raise ValueError("`--parallelism` must be at least 1.")
|
||||
@@ -408,178 +382,58 @@ async def generate_answers(
|
||||
write_lock = asyncio.Lock()
|
||||
completed = 0
|
||||
successful = 0
|
||||
stuck_count = 0
|
||||
failed_questions: list[FailedQuestionRecord] = []
|
||||
remaining_count = len(questions)
|
||||
overall_total = remaining_count + skipped
|
||||
question_durations: list[float] = []
|
||||
run_start_time = time.monotonic()
|
||||
|
||||
def print_progress() -> None:
|
||||
avg_time = (
|
||||
sum(question_durations) / len(question_durations)
|
||||
if question_durations
|
||||
else 0.0
|
||||
)
|
||||
elapsed = time.monotonic() - run_start_time
|
||||
eta = avg_time * (remaining_count - completed) / max(parallelism, 1)
|
||||
|
||||
done = skipped + completed
|
||||
bar_width = 30
|
||||
filled = (
|
||||
int(bar_width * done / overall_total)
|
||||
if overall_total
|
||||
else bar_width
|
||||
)
|
||||
bar = "█" * filled + "░" * (bar_width - filled)
|
||||
pct = (done / overall_total * 100) if overall_total else 100.0
|
||||
|
||||
parts = (
|
||||
f"\r{bar} {pct:5.1f}% "
|
||||
f"[{done}/{overall_total}] "
|
||||
f"avg {avg_time:.1f}s/q "
|
||||
f"elapsed {elapsed:.0f}s "
|
||||
f"ETA {eta:.0f}s "
|
||||
f"(ok:{successful} fail:{len(failed_questions)}"
|
||||
)
|
||||
if stuck_count:
|
||||
parts += f" stuck:{stuck_count}"
|
||||
if skipped:
|
||||
parts += f" skip:{skipped}"
|
||||
parts += ")"
|
||||
|
||||
sys.stderr.write(parts)
|
||||
sys.stderr.flush()
|
||||
|
||||
print_progress()
|
||||
total = len(questions)
|
||||
|
||||
async def process_question(question_record: QuestionRecord) -> None:
|
||||
nonlocal completed
|
||||
nonlocal successful
|
||||
nonlocal stuck_count
|
||||
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(1, MAX_QUESTION_ATTEMPTS + 1):
|
||||
q_start = time.monotonic()
|
||||
try:
|
||||
async with semaphore:
|
||||
result = await asyncio.wait_for(
|
||||
submit_question(
|
||||
session=session,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
internal_search_tool_id=internal_search_tool_id,
|
||||
question_record=question_record,
|
||||
),
|
||||
timeout=QUESTION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
async with progress_lock:
|
||||
stuck_count += 1
|
||||
logger.warning(
|
||||
"Question %s timed out after %ss (attempt %s/%s, "
|
||||
"total stuck: %s) — retrying in %ss",
|
||||
question_record.question_id,
|
||||
QUESTION_TIMEOUT_SECONDS,
|
||||
attempt,
|
||||
MAX_QUESTION_ATTEMPTS,
|
||||
stuck_count,
|
||||
QUESTION_RETRY_PAUSE_SECONDS,
|
||||
)
|
||||
print_progress()
|
||||
last_error = TimeoutError(
|
||||
f"Timed out after {QUESTION_TIMEOUT_SECONDS}s "
|
||||
f"on attempt {attempt}/{MAX_QUESTION_ATTEMPTS}"
|
||||
try:
|
||||
async with semaphore:
|
||||
result = await submit_question(
|
||||
session=session,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
internal_search_tool_id=internal_search_tool_id,
|
||||
question_record=question_record,
|
||||
)
|
||||
await asyncio.sleep(QUESTION_RETRY_PAUSE_SECONDS)
|
||||
continue
|
||||
except Exception as exc:
|
||||
duration = time.monotonic() - q_start
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
question_durations.append(duration)
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
"Failed question %s (%s/%s)",
|
||||
question_record.question_id,
|
||||
completed,
|
||||
remaining_count,
|
||||
)
|
||||
print_progress()
|
||||
return
|
||||
|
||||
duration = time.monotonic() - q_start
|
||||
|
||||
async with write_lock:
|
||||
file.write(json.dumps(asdict(result), ensure_ascii=False))
|
||||
file.write("\n")
|
||||
file.flush()
|
||||
|
||||
except Exception as exc:
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
successful += 1
|
||||
question_durations.append(duration)
|
||||
print_progress()
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
"Failed question %s (%s/%s)",
|
||||
question_record.question_id,
|
||||
completed,
|
||||
total,
|
||||
)
|
||||
return
|
||||
|
||||
# All attempts exhausted due to timeouts
|
||||
async with write_lock:
|
||||
file.write(json.dumps(asdict(result), ensure_ascii=False))
|
||||
file.write("\n")
|
||||
file.flush()
|
||||
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(last_error),
|
||||
)
|
||||
)
|
||||
logger.error(
|
||||
"Question %s failed after %s timeout attempts (%s/%s)",
|
||||
question_record.question_id,
|
||||
MAX_QUESTION_ATTEMPTS,
|
||||
completed,
|
||||
remaining_count,
|
||||
)
|
||||
print_progress()
|
||||
successful += 1
|
||||
logger.info("Processed %s/%s questions", completed, total)
|
||||
|
||||
await asyncio.gather(
|
||||
*(process_question(question_record) for question_record in questions)
|
||||
)
|
||||
|
||||
# Final newline after progress bar
|
||||
sys.stderr.write("\n")
|
||||
sys.stderr.flush()
|
||||
|
||||
total_elapsed = time.monotonic() - run_start_time
|
||||
avg_time = (
|
||||
sum(question_durations) / len(question_durations)
|
||||
if question_durations
|
||||
else 0.0
|
||||
)
|
||||
stuck_suffix = f", {stuck_count} stuck timeouts" if stuck_count else ""
|
||||
resume_suffix = (
|
||||
f" — {skipped} previously completed, "
|
||||
f"{skipped + successful}/{overall_total} overall"
|
||||
if skipped
|
||||
else ""
|
||||
)
|
||||
logger.info(
|
||||
"Done: %s/%s successful in %.1fs (avg %.1fs/question%s)%s",
|
||||
successful,
|
||||
remaining_count,
|
||||
total_elapsed,
|
||||
avg_time,
|
||||
stuck_suffix,
|
||||
resume_suffix,
|
||||
)
|
||||
|
||||
if failed_questions:
|
||||
logger.warning(
|
||||
"%s questions failed:",
|
||||
"Completed with %s failed questions and %s successful questions.",
|
||||
len(failed_questions),
|
||||
successful,
|
||||
)
|
||||
for failed_question in failed_questions:
|
||||
logger.warning(
|
||||
@@ -599,30 +453,7 @@ def main() -> None:
|
||||
raise ValueError("`--max-questions` must be at least 1 when provided.")
|
||||
questions = questions[: args.max_questions]
|
||||
|
||||
completed_ids = load_completed_question_ids(args.output_file)
|
||||
logger.info(
|
||||
"Found %s already-answered question IDs in %s",
|
||||
len(completed_ids),
|
||||
args.output_file,
|
||||
)
|
||||
total_before_filter = len(questions)
|
||||
questions = [q for q in questions if q.question_id not in completed_ids]
|
||||
skipped = total_before_filter - len(questions)
|
||||
|
||||
if skipped:
|
||||
logger.info(
|
||||
"Resuming: %s/%s already answered, %s remaining",
|
||||
skipped,
|
||||
total_before_filter,
|
||||
len(questions),
|
||||
)
|
||||
else:
|
||||
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
|
||||
|
||||
if not questions:
|
||||
logger.info("All questions already answered. Nothing to do.")
|
||||
return
|
||||
|
||||
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
|
||||
logger.info("Writing answers to %s", args.output_file)
|
||||
|
||||
asyncio.run(
|
||||
@@ -632,7 +463,6 @@ def main() -> None:
|
||||
api_base=api_base,
|
||||
api_key=args.api_key,
|
||||
parallelism=args.parallelism,
|
||||
skipped=skipped,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -27,11 +27,13 @@ def create_placement(
|
||||
turn_index: int,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
model_index: int | None = 0,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
model_index=model_index,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ This test:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -20,7 +21,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -137,7 +138,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -200,7 +201,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -275,7 +276,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -350,7 +351,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -458,7 +459,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -541,7 +542,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
|
||||
@@ -8,6 +8,7 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
@@ -16,7 +17,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -174,7 +175,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -232,7 +233,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -284,7 +285,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -345,7 +346,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -416,7 +417,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -483,7 +484,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -536,7 +537,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=get_default_emitter(),
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
173
backend/tests/unit/onyx/chat/test_emitter.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Unit tests for the Emitter class.
|
||||
|
||||
All tests use the streaming mode (merged_queue required). Emitter has a single
|
||||
code path — no standalone bus.
|
||||
"""
|
||||
|
||||
import queue
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _placement(
    turn_index: int = 0,
    tab_index: int = 0,
    sub_turn_index: int | None = None,
) -> Placement:
    """Create a Placement from the given indices (all default to the first slot)."""
    coordinates = {
        "turn_index": turn_index,
        "tab_index": tab_index,
        "sub_turn_index": sub_turn_index,
    }
    return Placement(**coordinates)
|
||||
|
||||
|
||||
def _packet(
    turn_index: int = 0,
    tab_index: int = 0,
    sub_turn_index: int | None = None,
) -> Packet:
    """Return a small Packet carrying an OverallStop at the requested placement."""
    where = _placement(turn_index, tab_index, sub_turn_index)
    payload = OverallStop(stop_reason="test")
    return Packet(placement=where, obj=payload)
|
||||
|
||||
|
||||
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
    """Create an Emitter bound to a fresh queue and return both.

    The queue is handed back so tests can inspect what the emitter put on it.
    """
    merged: queue.Queue = queue.Queue()
    emitter = Emitter(merged_queue=merged, model_idx=model_idx)
    return emitter, merged
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueRouting:
    """Emitted packets must end up on the merged queue, in order."""

    def test_emit_lands_on_merged_queue(self) -> None:
        """After one emit the queue is no longer empty."""
        emitter, merged = _make_emitter()
        emitter.emit(_packet())
        assert merged.empty() is False

    def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
        """Each queue entry is a 2-tuple."""
        emitter, merged = _make_emitter(model_idx=1)
        emitter.emit(_packet())
        entry = merged.get_nowait()
        assert isinstance(entry, tuple)
        assert len(entry) == 2

    def test_multiple_packets_delivered_fifo(self) -> None:
        """Two packets come off the queue in the order they were emitted."""
        emitter, merged = _make_emitter()
        first = _packet(turn_index=0)
        second = _packet(turn_index=1)
        for outgoing in (first, second):
            emitter.emit(outgoing)
        _, received_first = merged.get_nowait()
        _, received_second = merged.get_nowait()
        assert received_first.placement.turn_index == 0
        assert received_second.placement.turn_index == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_index tagging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterModelIndexTagging:
    """The packet taken off the queue carries the emitter's model_idx."""

    def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
        """Single-model case: model_idx=0 produces model_index=0 on the packet."""
        emitter, merged = _make_emitter(model_idx=0)
        emitter.emit(_packet())
        queued = merged.get_nowait()
        assert queued[1].placement.model_index == 0

    def test_model_idx_one_tags_packet(self) -> None:
        """model_idx=1 produces model_index=1 on the packet."""
        emitter, merged = _make_emitter(model_idx=1)
        emitter.emit(_packet())
        queued = merged.get_nowait()
        assert queued[1].placement.model_index == 1

    def test_model_idx_two_tags_packet(self) -> None:
        """Third model of a three-model run: model_index=2 on the packet."""
        emitter, merged = _make_emitter(model_idx=2)
        emitter.emit(_packet())
        queued = merged.get_nowait()
        assert queued[1].placement.model_index == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueKey:
    """The first element of each queue entry is the emitter's model_idx."""

    def test_key_equals_model_idx(self) -> None:
        """The key is what the drain loop routes on, so it has to equal model_idx."""
        emitter, merged = _make_emitter(model_idx=2)
        emitter.emit(_packet())
        routing_key = merged.get_nowait()[0]
        assert routing_key == 2

    def test_n1_key_is_zero(self) -> None:
        """Single-model case: the key is 0."""
        emitter, merged = _make_emitter(model_idx=0)
        emitter.emit(_packet())
        routing_key = merged.get_nowait()[0]
        assert routing_key == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Placement field preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterPlacementPreservation:
    """Placement fields and the payload must come out of the queue unchanged."""

    def test_turn_index_is_preserved(self) -> None:
        emitter, merged = _make_emitter()
        emitter.emit(_packet(turn_index=5))
        queued_packet = merged.get_nowait()[1]
        assert queued_packet.placement.turn_index == 5

    def test_tab_index_is_preserved(self) -> None:
        emitter, merged = _make_emitter()
        emitter.emit(_packet(tab_index=3))
        queued_packet = merged.get_nowait()[1]
        assert queued_packet.placement.tab_index == 3

    def test_sub_turn_index_is_preserved(self) -> None:
        emitter, merged = _make_emitter()
        emitter.emit(_packet(sub_turn_index=2))
        queued_packet = merged.get_nowait()[1]
        assert queued_packet.placement.sub_turn_index == 2

    def test_sub_turn_index_none_is_preserved(self) -> None:
        emitter, merged = _make_emitter()
        emitter.emit(_packet(sub_turn_index=None))
        queued_packet = merged.get_nowait()[1]
        assert queued_packet.placement.sub_turn_index is None

    def test_packet_obj_is_not_modified(self) -> None:
        """The very same payload object must be found on the queued packet."""
        emitter, merged = _make_emitter()
        payload = OverallStop(stop_reason="sentinel")
        emitter.emit(Packet(placement=_placement(), obj=payload))
        queued_packet = merged.get_nowait()[1]
        assert queued_packet.obj is payload

    def test_different_obj_types_are_handled(self) -> None:
        """A payload type other than OverallStop also passes through."""
        emitter, merged = _make_emitter()
        outgoing = Packet(placement=_placement(), obj=ReasoningStart())
        emitter.emit(outgoing)
        queued_packet = merged.get_nowait()[1]
        assert isinstance(queued_packet.obj, ReasoningStart)
|
||||
662
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
662
backend/tests/unit/onyx/chat/test_multi_model_streaming.py
Normal file
@@ -0,0 +1,662 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in handle_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _restore_ee_version() -> Generator[None, None, None]:
    """Put the EE flag back to its pre-test value once the test has run.

    Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
    (via the celery client import chain). Left alone, the EE flag would remain
    True for the rest of the session and break unrelated tests that mock
    Confluence or other connectors and assume EE is disabled.
    """
    saved_flag = global_version._is_ee
    yield
    # Teardown: write back whatever value was present before the test body ran.
    global_version._is_ee = saved_flag
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
    """Build a SendMessageRequest with a default message and a random session id.

    Any keyword argument overrides or extends those defaults.
    """
    fields: dict[str, Any] = {
        "message": "hello",
        "chat_session_id": uuid4(),
        **kwargs,
    }
    return SendMessageRequest(**fields)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
    """Create an LLMOverride for the given provider/version pair."""
    override = LLMOverride(model_provider=provider, model_version=version)
    return override
|
||||
|
||||
|
||||
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
    """Start handle_multi_model_stream with mock user/db and return its first item."""
    # Imported inside the function, as in the original, rather than at module load.
    from onyx.chat.process_message import handle_multi_model_stream

    mock_user = MagicMock()
    mock_user.is_anonymous = False
    mock_user.email = "test@example.com"
    mock_db = MagicMock()

    stream = handle_multi_model_stream(req, mock_user, mock_db, overrides)
    return next(stream)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
    """Checks on the first item handle_multi_model_stream yields for bad input."""

    def test_single_override_yields_error(self) -> None:
        """A single override is not a multi-model request: StreamingError expected."""
        first = _first_from_stream(_make_request(), [_make_override()])
        assert isinstance(first, StreamingError)
        assert "2-3" in first.error

    def test_four_overrides_yields_error(self) -> None:
        """Four overrides is one too many: StreamingError expected."""
        req = _make_request()
        too_many = [
            _make_override(provider, version)
            for provider, version in (
                ("openai", "gpt-4"),
                ("anthropic", "claude-3"),
                ("google", "gemini-pro"),
                ("cohere", "command-r"),
            )
        ]
        first = _first_from_stream(req, too_many)
        assert isinstance(first, StreamingError)
        assert "2-3" in first.error

    def test_zero_overrides_yields_error(self) -> None:
        """No overrides at all: StreamingError expected."""
        first = _first_from_stream(_make_request(), [])
        assert isinstance(first, StreamingError)
        assert "2-3" in first.error

    def test_deep_research_yields_error(self) -> None:
        """deep_research=True combined with two overrides: StreamingError expected."""
        req = _make_request(deep_research=True)
        pair = [_make_override(), _make_override("anthropic", "claude-3")]
        first = _first_from_stream(req, pair)
        assert isinstance(first, StreamingError)
        assert "not supported" in first.error

    def test_exactly_two_overrides_is_minimum(self) -> None:
        """Lower boundary: one override is rejected, two get past the count check."""
        req = _make_request()

        # One override has to be rejected.
        rejected = _first_from_stream(req, [_make_override()])
        assert isinstance(
            rejected, StreamingError
        ), "1 override should yield StreamingError"

        # With two overrides the count check must not fire. Later failures
        # (e.g. a missing session) are acceptable and are swallowed below.
        # NOTE(review): this relies on pytest.fail raising something that is not
        # an ``Exception`` subclass, so the handler does not hide it - confirm.
        try:
            accepted = _first_from_stream(
                req, [_make_override(), _make_override("anthropic", "claude-3")]
            )
            count_error = isinstance(accepted, StreamingError) and (
                "2-3" in accepted.error
            )
            if count_error:
                pytest.fail(
                    f"2 overrides should pass validation, got StreamingError: {accepted.error}"
                )
        except Exception:
            pass  # Any non-validation error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
    """Input checks of set_preferred_response, run against a MagicMock session."""

    def test_user_message_not_found(self) -> None:
        """A lookup that returns None for the user message raises ValueError."""
        session = MagicMock()
        session.get.return_value = None

        with pytest.raises(ValueError, match="not found"):
            set_preferred_response(
                session, user_message_id=999, preferred_assistant_message_id=1
            )

    def test_wrong_message_type(self) -> None:
        """The target of user_message_id must be a USER message."""
        session = MagicMock()
        # Deliberately the wrong type for a "user" message.
        not_a_user_msg = MagicMock(message_type=MessageType.ASSISTANT)
        session.get.return_value = not_a_user_msg

        with pytest.raises(ValueError, match="not a user message"):
            set_preferred_response(
                session, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_assistant_message_not_found(self) -> None:
        """A missing assistant message raises ValueError."""
        session = MagicMock()
        parent = MagicMock(message_type=MessageType.USER)
        # Lookup order: the user message first, then the (absent) assistant one.
        session.get.side_effect = [parent, None]

        with pytest.raises(ValueError, match="not found"):
            set_preferred_response(
                session, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_assistant_not_child_of_user(self) -> None:
        """An assistant message with a different parent raises ValueError."""
        session = MagicMock()
        parent = MagicMock(message_type=MessageType.USER)
        # parent_message_id points somewhere other than user_message_id=1.
        stranger = MagicMock(parent_message_id=999)
        session.get.side_effect = [parent, stranger]

        with pytest.raises(ValueError, match="not a child"):
            set_preferred_response(
                session, user_message_id=1, preferred_assistant_message_id=2
            )

    def test_valid_call_sets_preferred_response_id(self) -> None:
        """On success both id fields of the user message are set to the assistant id."""
        session = MagicMock()
        parent = MagicMock(message_type=MessageType.USER)
        # parent_message_id matches user_message_id=1.
        child = MagicMock(parent_message_id=1)
        session.get.side_effect = [parent, child]

        set_preferred_response(
            session, user_message_id=1, preferred_assistant_message_id=2
        )

        assert parent.preferred_response_id == 2
        assert parent.latest_child_message_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
    """Behaviour of the display_name field on LLMOverride."""

    def test_display_name_defaults_none(self) -> None:
        """display_name is None when not supplied."""
        bare = LLMOverride(model_provider="openai", model_version="gpt-4")
        assert bare.display_name is None

    def test_display_name_set(self) -> None:
        """A supplied display_name is stored as given."""
        named = LLMOverride(
            model_provider="openai",
            model_version="gpt-4",
            display_name="GPT-4 Turbo",
        )
        assert named.display_name == "GPT-4 Turbo"

    def test_display_name_serializes(self) -> None:
        """model_dump() includes the display_name value."""
        named = LLMOverride(
            model_provider="anthropic",
            model_version="claude-opus-4-6",
            display_name="Claude Opus",
        )
        dumped = named.model_dump()
        assert dumped["display_name"] == "Claude Opus"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_models — drain loop behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_setup(n_models: int = 1) -> MagicMock:
    """Build a MagicMock standing in for ChatTurnSetup with ``n_models`` models.

    Attributes are pinned to concrete values so that the config objects built
    inside _run_model (SearchToolConfig / CustomToolConfig /
    FileReaderToolConfig) pass Pydantic validation.
    """
    turn = MagicMock()

    # One entry per model.
    turn.llms = [MagicMock() for _ in range(n_models)]
    turn.model_display_names = [f"model-{i}" for i in range(n_models)]
    turn.reserved_messages = [MagicMock() for _ in range(n_models)]

    turn.check_is_connected = MagicMock(return_value=True)
    turn.reserved_token_count = 100

    # Request-level flags.
    request = turn.new_msg_req
    request.deep_research = False
    request.internal_search_filters = None
    request.allowed_tool_ids = None
    request.include_citations = True

    # Search parameters.
    search = turn.search_params
    search.project_id_filter = None
    search.persona_id_filter = None

    # File availability.
    files = turn.available_files
    files.user_file_ids = []
    files.chat_file_ids = []

    # Remaining scalar fields.
    turn.bypass_acl = False
    turn.slack_context = None
    turn.forced_tool_id = None
    turn.simple_chat_history = []
    turn.chat_session.id = uuid4()
    turn.user_message.id = None
    turn.custom_tool_additional_headers = None
    turn.mcp_headers = None
    return turn
|
||||
|
||||
|
||||
# Patchers for the external dependencies of onyx.chat.process_message that the
# _run_models tests replace (LLM loops, tool construction, DB session,
# completion handler, token counter).
_RUN_MODELS_PATCHES = [
    patch("onyx.chat.process_message.run_llm_loop"),
    patch("onyx.chat.process_message.run_deep_research_llm_loop"),
    patch("onyx.chat.process_message.construct_tools", return_value={}),
    patch("onyx.chat.process_message.get_session_with_current_tenant"),
    patch("onyx.chat.process_message.llm_loop_completion_handle"),
    patch(
        "onyx.chat.process_message.get_llm_token_counter",
        return_value=lambda _: 0,
    ),
]
|
||||
|
||||
|
||||
def _run_models_collect(setup: MagicMock) -> list:
    """Exhaust the _run_models generator and return everything it yielded."""
    # Imported inside the function, as in the original, rather than at module load.
    from onyx.chat.process_message import _run_models

    stream = _run_models(setup, MagicMock(), MagicMock())
    return list(stream)
|
||||
|
||||
|
||||
class TestRunModels:
|
||||
"""Tests for the _run_models worker-thread drain loop.
|
||||
|
||||
All external dependencies (LLM, DB, tools) are patched out. Worker threads
|
||||
still run but return immediately since run_llm_loop is mocked.
|
||||
"""
|
||||
|
||||
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
    """A stop packet emitted from within run_llm_loop is yielded as-is."""

    def fake_llm_loop(**kwargs: Any) -> None:
        kwargs["emitter"].emit(
            Packet(
                placement=Placement(turn_index=0),
                obj=OverallStop(stop_reason="complete"),
            )
        )

    with (
        patch("onyx.chat.process_message.run_llm_loop", side_effect=fake_llm_loop),
        patch("onyx.chat.process_message.run_deep_research_llm_loop"),
        patch("onyx.chat.process_message.construct_tools", return_value={}),
        patch("onyx.chat.process_message.get_session_with_current_tenant"),
        patch("onyx.chat.process_message.llm_loop_completion_handle"),
        patch(
            "onyx.chat.process_message.get_llm_token_counter",
            return_value=lambda _: 0,
        ),
    ):
        yielded = _run_models_collect(_make_setup(n_models=1))

    stop_packets = [
        item
        for item in yielded
        if isinstance(item, Packet) and isinstance(item.obj, OverallStop)
    ]
    assert len(stop_packets) == 1
    only_stop = stop_packets[0].obj
    assert isinstance(only_stop, OverallStop)
    assert only_stop.stop_reason == "complete"
|
||||
|
||||
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
    """With one model, the yielded packet carries model_index 0."""

    def fake_llm_loop(**kwargs: Any) -> None:
        kwargs["emitter"].emit(
            Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
        )

    with (
        patch("onyx.chat.process_message.run_llm_loop", side_effect=fake_llm_loop),
        patch("onyx.chat.process_message.run_deep_research_llm_loop"),
        patch("onyx.chat.process_message.construct_tools", return_value={}),
        patch("onyx.chat.process_message.get_session_with_current_tenant"),
        patch("onyx.chat.process_message.llm_loop_completion_handle"),
        patch(
            "onyx.chat.process_message.get_llm_token_counter",
            return_value=lambda _: 0,
        ),
    ):
        yielded = _run_models_collect(_make_setup(n_models=1))

    reasoning_packets = [
        item
        for item in yielded
        if isinstance(item, Packet) and isinstance(item.obj, ReasoningStart)
    ]
    assert len(reasoning_packets) == 1
    assert reasoning_packets[0].placement.model_index == 0
|
||||
|
||||
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
    """With two models, one packet arrives tagged 0 and one tagged 1."""

    def fake_llm_loop(**kwargs: Any) -> None:
        # Each call receives its own emitter; the test only emits one packet
        # through it and checks the resulting model_index afterwards.
        model_emitter = kwargs["emitter"]
        model_emitter.emit(
            Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
        )

    with (
        patch("onyx.chat.process_message.run_llm_loop", side_effect=fake_llm_loop),
        patch("onyx.chat.process_message.run_deep_research_llm_loop"),
        patch("onyx.chat.process_message.construct_tools", return_value={}),
        patch("onyx.chat.process_message.get_session_with_current_tenant"),
        patch("onyx.chat.process_message.llm_loop_completion_handle"),
        patch(
            "onyx.chat.process_message.get_llm_token_counter",
            return_value=lambda _: 0,
        ),
    ):
        yielded = _run_models_collect(_make_setup(n_models=2))

    reasoning_packets = [
        item
        for item in yielded
        if isinstance(item, Packet) and isinstance(item.obj, ReasoningStart)
    ]
    assert len(reasoning_packets) == 2
    seen_indices = {pkt.placement.model_index for pkt in reasoning_packets}
    assert seen_indices == {0, 1}
|
||||
|
||||
def test_model_error_yields_streaming_error(self) -> None:
    """A RuntimeError raised by run_llm_loop shows up as one StreamingError."""

    def exploding_llm_loop(**_kwargs: Any) -> None:
        raise RuntimeError("intentional test failure")

    with (
        patch(
            "onyx.chat.process_message.run_llm_loop",
            side_effect=exploding_llm_loop,
        ),
        patch("onyx.chat.process_message.run_deep_research_llm_loop"),
        patch("onyx.chat.process_message.construct_tools", return_value={}),
        patch("onyx.chat.process_message.get_session_with_current_tenant"),
        patch("onyx.chat.process_message.llm_loop_completion_handle"),
        patch(
            "onyx.chat.process_message.get_llm_token_counter",
            return_value=lambda _: 0,
        ),
    ):
        yielded = _run_models_collect(_make_setup(n_models=1))

    failures = [item for item in yielded if isinstance(item, StreamingError)]
    assert len(failures) == 1
    only_failure = failures[0]
    assert only_failure.error_code == "MODEL_ERROR"
    assert "intentional test failure" in only_failure.error
|
||||
|
||||
def test_one_model_error_does_not_stop_other_models(self) -> None:
    """Model 0 fails, model 1 succeeds: one StreamingError plus model 1's packet."""

    def first_fails_second_emits(**kwargs: Any) -> None:
        model_emitter = kwargs["emitter"]
        # The emitter's _model_idx tells the two invocations apart.
        if model_emitter._model_idx == 0:
            raise RuntimeError("model 0 failed")
        model_emitter.emit(
            Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
        )

    with (
        patch(
            "onyx.chat.process_message.run_llm_loop",
            side_effect=first_fails_second_emits,
        ),
        patch("onyx.chat.process_message.run_deep_research_llm_loop"),
        patch("onyx.chat.process_message.construct_tools", return_value={}),
        patch("onyx.chat.process_message.get_session_with_current_tenant"),
        patch("onyx.chat.process_message.llm_loop_completion_handle"),
        patch(
            "onyx.chat.process_message.get_llm_token_counter",
            return_value=lambda _: 0,
        ),
    ):
        yielded = _run_models_collect(_make_setup(n_models=2))

    failures = [item for item in yielded if isinstance(item, StreamingError)]
    assert len(failures) == 1

    reasoning_packets = [
        item
        for item in yielded
        if isinstance(item, Packet) and isinstance(item.obj, ReasoningStart)
    ]
    assert len(reasoning_packets) == 1
    assert reasoning_packets[0].placement.model_index == 1
|
||||
|
||||
def test_cancellation_yields_user_cancelled_stop(self) -> None:
    """When check_is_connected reports False, a user_cancelled stop is yielded."""

    def sleepy_llm_loop(**_kwargs: Any) -> None:
        # Sleep longer than the 50 ms queue-poll interval so the worker is
        # still busy when the drain loop polls.
        time.sleep(0.3)

    disconnected_setup = _make_setup(n_models=1)
    disconnected_setup.check_is_connected = MagicMock(return_value=False)

    with (
        patch("onyx.chat.process_message.run_llm_loop", side_effect=sleepy_llm_loop),
        patch("onyx.chat.process_message.run_deep_research_llm_loop"),
        patch("onyx.chat.process_message.construct_tools", return_value={}),
        patch("onyx.chat.process_message.get_session_with_current_tenant"),
        patch("onyx.chat.process_message.llm_loop_completion_handle"),
        patch(
            "onyx.chat.process_message.get_llm_token_counter",
            return_value=lambda _: 0,
        ),
    ):
        yielded = _run_models_collect(disconnected_setup)

    stop_packets = [
        item
        for item in yielded
        if isinstance(item, Packet) and isinstance(item.obj, OverallStop)
    ]
    cancelled_seen = any(
        isinstance(pkt.obj, OverallStop) and pkt.obj.stop_reason == "user_cancelled"
        for pkt in stop_packets
    )
    assert cancelled_seen
|
||||
|
||||
def test_completion_handle_called_on_disconnect(self) -> None:
    """A user disconnect must not prevent llm_loop_completion_handle from running.

    Regression test for the disconnect-cleanup bug. The earlier
    run_chat_loop_with_state_containers invoked completion_callback from its
    finally block even on disconnect, which replaced the TERMINATED
    placeholder in the DB with the partial answer. _run_models has to behave
    the same way; if it does not, the integration test
    test_send_message_disconnect_and_cleanup fails because the message is
    left as "Response was terminated prior to completion, try regenerating."
    """

    def sleepy_llm_loop(**_kwargs: Any) -> None:
        time.sleep(0.3)

    disconnected_setup = _make_setup(n_models=2)
    disconnected_setup.check_is_connected = MagicMock(return_value=False)

    with (
        patch("onyx.chat.process_message.run_llm_loop", side_effect=sleepy_llm_loop),
        patch("onyx.chat.process_message.run_deep_research_llm_loop"),
        patch("onyx.chat.process_message.construct_tools", return_value={}),
        patch("onyx.chat.process_message.get_session_with_current_tenant"),
        patch(
            "onyx.chat.process_message.llm_loop_completion_handle"
        ) as completion_mock,
        patch(
            "onyx.chat.process_message.get_llm_token_counter",
            return_value=lambda _: 0,
        ),
    ):
        _run_models_collect(disconnected_setup)

    # Expect exactly one call per model (two models here), never zero.
    assert completion_mock.call_count == 2
|
||||
|
||||
def test_completion_handle_called_for_each_successful_model(self) -> None:
    """Two models that both finish normally give two completion-handle calls."""
    two_model_setup = _make_setup(n_models=2)

    with (
        patch("onyx.chat.process_message.run_llm_loop"),
        patch("onyx.chat.process_message.run_deep_research_llm_loop"),
        patch("onyx.chat.process_message.construct_tools", return_value={}),
        patch("onyx.chat.process_message.get_session_with_current_tenant"),
        patch(
            "onyx.chat.process_message.llm_loop_completion_handle"
        ) as completion_mock,
        patch(
            "onyx.chat.process_message.get_llm_token_counter",
            return_value=lambda _: 0,
        ),
    ):
        _run_models_collect(two_model_setup)

    assert completion_mock.call_count == 2
|
||||
|
||||
def test_completion_handle_not_called_for_failed_model(self) -> None:
    """When run_llm_loop raises, llm_loop_completion_handle is never invoked."""

    def exploding_llm_loop(**_kwargs: Any) -> None:
        raise RuntimeError("fail")

    with (
        patch(
            "onyx.chat.process_message.run_llm_loop",
            side_effect=exploding_llm_loop,
        ),
        patch("onyx.chat.process_message.run_deep_research_llm_loop"),
        patch("onyx.chat.process_message.construct_tools", return_value={}),
        patch("onyx.chat.process_message.get_session_with_current_tenant"),
        patch(
            "onyx.chat.process_message.llm_loop_completion_handle"
        ) as completion_mock,
        patch(
            "onyx.chat.process_message.get_llm_token_counter",
            return_value=lambda _: 0,
        ),
    ):
        _run_models_collect(_make_setup(n_models=1))

    completion_mock.assert_not_called()
|
||||
|
||||
def test_http_disconnect_completion_via_generator_exit(self) -> None:
|
||||
"""GeneratorExit from HTTP disconnect triggers wait+completion in finally.
|
||||
|
||||
When the HTTP client closes the connection, Starlette throws GeneratorExit
|
||||
into the stream generator, which propagates into _run_models. The finally
|
||||
block must call executor.shutdown(wait=True) to wait for LLM threads to
|
||||
finish, then persist their results via llm_loop_completion_handle.
|
||||
|
||||
This is the primary regression for test_send_message_disconnect_and_cleanup:
|
||||
the integration test disconnects mid-stream and expects the DB message to be
|
||||
updated from the TERMINATED placeholder to the real response.
|
||||
"""
|
||||
import threading
|
||||
|
||||
thread_completed = threading.Event()
|
||||
|
||||
def emit_then_complete(**kwargs: Any) -> None:
|
||||
"""Emit one packet (to give generator a yield point), then finish."""
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
# Small sleep so executor.shutdown(wait=True) in finally actually waits.
|
||||
time.sleep(0.05)
|
||||
thread_completed.set()
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_then_complete,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
# cast to Generator so .close() is available; _run_models returns
|
||||
# AnswerStream (= Iterator) but the actual object is always a generator.
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
# Advance to the first yielded packet — generator suspends at `yield item`.
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
# Simulate Starlette closing the stream on HTTP client disconnect.
|
||||
# GeneratorExit is thrown at the `yield item` suspension point.
|
||||
gen.close()
|
||||
|
||||
# Finally block must have waited for the thread and saved completion.
|
||||
assert (
|
||||
thread_completed.is_set()
|
||||
), "LLM thread must complete before gen.close() returns"
|
||||
assert (
|
||||
mock_handle.call_count == 1
|
||||
), "completion handle must be called for the successful model"
|
||||
|
||||
def test_external_state_container_used_for_model_zero(self) -> None:
|
||||
"""When provided, external_state_container is used as state_containers[0]."""
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
external = ChatStateContainer()
|
||||
setup = _make_setup(n_models=1)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
list(
|
||||
_run_models(
|
||||
setup, MagicMock(), MagicMock(), external_state_container=external
|
||||
)
|
||||
)
|
||||
|
||||
# The state_container kwarg passed to run_llm_loop must be the external one
|
||||
call_kwargs = mock_llm.call_args.kwargs
|
||||
assert call_kwargs["state_container"] is external
|
||||
@@ -1,23 +1,15 @@
|
||||
"""Tests for Canvas connector — client, credentials, conversion."""
|
||||
"""Tests for Canvas connector — client (PR1)."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.connectors.canvas.connector import CanvasConnector
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -26,77 +18,6 @@ FAKE_BASE_URL = "https://myschool.instructure.com"
|
||||
FAKE_TOKEN = "fake-canvas-token"
|
||||
|
||||
|
||||
def _mock_course(
|
||||
course_id: int = 1,
|
||||
name: str = "Intro to CS",
|
||||
course_code: str = "CS101",
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"id": course_id,
|
||||
"name": name,
|
||||
"course_code": course_code,
|
||||
"created_at": "2025-01-01T00:00:00Z",
|
||||
"workflow_state": "available",
|
||||
}
|
||||
|
||||
|
||||
def _build_connector(base_url: str = FAKE_BASE_URL) -> CanvasConnector:
|
||||
"""Build a connector with mocked credential validation."""
|
||||
with patch("onyx.connectors.canvas.client.rl_requests") as mock_req:
|
||||
mock_req.get.return_value = _mock_response(json_data=[_mock_course()])
|
||||
connector = CanvasConnector(canvas_base_url=base_url)
|
||||
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
|
||||
return connector
|
||||
|
||||
|
||||
def _mock_page(
|
||||
page_id: int = 10,
|
||||
title: str = "Syllabus",
|
||||
updated_at: str = "2025-06-01T12:00:00Z",
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"page_id": page_id,
|
||||
"url": "syllabus",
|
||||
"title": title,
|
||||
"body": "<p>Welcome to the course</p>",
|
||||
"created_at": "2025-01-15T00:00:00Z",
|
||||
"updated_at": updated_at,
|
||||
}
|
||||
|
||||
|
||||
def _mock_assignment(
|
||||
assignment_id: int = 20,
|
||||
name: str = "Homework 1",
|
||||
course_id: int = 1,
|
||||
updated_at: str = "2025-06-01T12:00:00Z",
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"id": assignment_id,
|
||||
"name": name,
|
||||
"description": "<p>Solve these problems</p>",
|
||||
"html_url": f"{FAKE_BASE_URL}/courses/{course_id}/assignments/{assignment_id}",
|
||||
"course_id": course_id,
|
||||
"created_at": "2025-01-20T00:00:00Z",
|
||||
"updated_at": updated_at,
|
||||
"due_at": "2025-02-01T23:59:00Z",
|
||||
}
|
||||
|
||||
|
||||
def _mock_announcement(
|
||||
announcement_id: int = 30,
|
||||
title: str = "Class Cancelled",
|
||||
course_id: int = 1,
|
||||
posted_at: str = "2025-06-01T12:00:00Z",
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"id": announcement_id,
|
||||
"title": title,
|
||||
"message": "<p>No class today</p>",
|
||||
"html_url": f"{FAKE_BASE_URL}/courses/{course_id}/discussion_topics/{announcement_id}",
|
||||
"posted_at": posted_at,
|
||||
}
|
||||
|
||||
|
||||
def _mock_response(
|
||||
status_code: int = 200,
|
||||
json_data: Any = None,
|
||||
@@ -404,57 +325,6 @@ class TestGet:
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient.paginate tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPaginate:
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_single_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[{"id": 1}, {"id": 2}]
|
||||
)
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
pages = list(client.paginate("courses"))
|
||||
|
||||
assert len(pages) == 1
|
||||
assert pages[0] == [{"id": 1}, {"id": 2}]
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_two_pages(self, mock_requests: MagicMock) -> None:
|
||||
next_link = f'<{FAKE_BASE_URL}/api/v1/courses?page=2>; rel="next"'
|
||||
page1 = _mock_response(json_data=[{"id": 1}], link_header=next_link)
|
||||
page2 = _mock_response(json_data=[{"id": 2}])
|
||||
mock_requests.get.side_effect = [page1, page2]
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
pages = list(client.paginate("courses"))
|
||||
|
||||
assert len(pages) == 2
|
||||
assert pages[0] == [{"id": 1}]
|
||||
assert pages[1] == [{"id": 2}]
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_empty_response(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
pages = list(client.paginate("courses"))
|
||||
|
||||
assert pages == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient._parse_next_link tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -509,368 +379,3 @@ class TestParseNextLink:
|
||||
|
||||
with pytest.raises(OnyxError, match="must use https"):
|
||||
self.client._parse_next_link(header)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasConnector — credential loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadCredentials:
|
||||
def _assert_load_credentials_raises(
|
||||
self,
|
||||
status_code: int,
|
||||
expected_error: type[Exception],
|
||||
mock_requests: MagicMock,
|
||||
) -> None:
|
||||
"""Helper: assert load_credentials raises expected_error for a given status."""
|
||||
mock_requests.get.return_value = _mock_response(status_code, {})
|
||||
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
|
||||
with pytest.raises(expected_error):
|
||||
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_load_credentials_success(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
|
||||
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
|
||||
|
||||
result = connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
|
||||
|
||||
assert result is None
|
||||
assert connector._canvas_client is not None
|
||||
|
||||
def test_canvas_client_raises_without_credentials(self) -> None:
|
||||
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
|
||||
|
||||
with pytest.raises(ConnectorMissingCredentialError):
|
||||
_ = connector.canvas_client
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_load_credentials_invalid_token(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_load_credentials_raises(401, CredentialExpiredError, mock_requests)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_load_credentials_insufficient_permissions(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
self._assert_load_credentials_raises(
|
||||
403, InsufficientPermissionsError, mock_requests
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasConnector — URL normalization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConnectorUrlNormalization:
|
||||
def test_strips_api_v1_suffix(self) -> None:
|
||||
connector = _build_connector(base_url=f"{FAKE_BASE_URL}/api/v1")
|
||||
|
||||
result = connector.canvas_base_url
|
||||
expected = FAKE_BASE_URL
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_strips_trailing_slash(self) -> None:
|
||||
connector = _build_connector(base_url=f"{FAKE_BASE_URL}/")
|
||||
|
||||
result = connector.canvas_base_url
|
||||
expected = FAKE_BASE_URL
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_no_change_for_clean_url(self) -> None:
|
||||
connector = _build_connector(base_url=FAKE_BASE_URL)
|
||||
|
||||
result = connector.canvas_base_url
|
||||
expected = FAKE_BASE_URL
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasConnector — document conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDocumentConversion:
|
||||
def setup_method(self) -> None:
|
||||
self.connector = _build_connector()
|
||||
|
||||
def test_convert_page_to_document(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasPage
|
||||
|
||||
page = CanvasPage(
|
||||
page_id=10,
|
||||
url="syllabus",
|
||||
title="Syllabus",
|
||||
body="<p>Welcome</p>",
|
||||
created_at="2025-01-15T00:00:00Z",
|
||||
updated_at="2025-06-01T12:00:00Z",
|
||||
course_id=1,
|
||||
)
|
||||
|
||||
doc = self.connector._convert_page_to_document(page)
|
||||
|
||||
expected_id = "canvas-page-1-10"
|
||||
expected_metadata = {"course_id": "1", "type": "page"}
|
||||
expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
|
||||
|
||||
assert doc.id == expected_id
|
||||
assert doc.source == DocumentSource.CANVAS
|
||||
assert doc.semantic_identifier == "Syllabus"
|
||||
assert doc.metadata == expected_metadata
|
||||
assert doc.sections[0].link is not None
|
||||
assert f"{FAKE_BASE_URL}/courses/1/pages/syllabus" in doc.sections[0].link
|
||||
assert doc.doc_updated_at == expected_updated_at
|
||||
|
||||
def test_convert_page_without_body(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasPage
|
||||
|
||||
page = CanvasPage(
|
||||
page_id=11,
|
||||
url="empty-page",
|
||||
title="Empty Page",
|
||||
body=None,
|
||||
created_at="2025-01-15T00:00:00Z",
|
||||
updated_at="2025-06-01T12:00:00Z",
|
||||
course_id=1,
|
||||
)
|
||||
|
||||
doc = self.connector._convert_page_to_document(page)
|
||||
section_text = doc.sections[0].text
|
||||
assert section_text is not None
|
||||
|
||||
assert "Empty Page" in section_text
|
||||
assert "<p>" not in section_text
|
||||
|
||||
def test_convert_assignment_to_document(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasAssignment
|
||||
|
||||
assignment = CanvasAssignment(
|
||||
id=20,
|
||||
name="Homework 1",
|
||||
description="<p>Solve these</p>",
|
||||
html_url=f"{FAKE_BASE_URL}/courses/1/assignments/20",
|
||||
course_id=1,
|
||||
created_at="2025-01-20T00:00:00Z",
|
||||
updated_at="2025-06-01T12:00:00Z",
|
||||
due_at="2025-02-01T23:59:00Z",
|
||||
)
|
||||
|
||||
doc = self.connector._convert_assignment_to_document(assignment)
|
||||
|
||||
expected_id = "canvas-assignment-1-20"
|
||||
expected_due_text = "Due: February 01, 2025 23:59 UTC"
|
||||
|
||||
assert doc.id == expected_id
|
||||
assert doc.source == DocumentSource.CANVAS
|
||||
assert doc.semantic_identifier == "Homework 1"
|
||||
assert doc.sections[0].text is not None
|
||||
assert expected_due_text in doc.sections[0].text
|
||||
|
||||
def test_convert_assignment_without_description(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasAssignment
|
||||
|
||||
assignment = CanvasAssignment(
|
||||
id=21,
|
||||
name="Quiz 1",
|
||||
description=None,
|
||||
html_url=f"{FAKE_BASE_URL}/courses/1/assignments/21",
|
||||
course_id=1,
|
||||
created_at="2025-01-20T00:00:00Z",
|
||||
updated_at="2025-06-01T12:00:00Z",
|
||||
due_at=None,
|
||||
)
|
||||
|
||||
doc = self.connector._convert_assignment_to_document(assignment)
|
||||
section_text = doc.sections[0].text
|
||||
assert section_text is not None
|
||||
|
||||
assert "Quiz 1" in section_text
|
||||
assert "Due:" not in section_text
|
||||
|
||||
def test_convert_announcement_to_document(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasAnnouncement
|
||||
|
||||
announcement = CanvasAnnouncement(
|
||||
id=30,
|
||||
title="Class Cancelled",
|
||||
message="<p>No class today</p>",
|
||||
html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/30",
|
||||
posted_at="2025-06-01T12:00:00Z",
|
||||
course_id=1,
|
||||
)
|
||||
|
||||
doc = self.connector._convert_announcement_to_document(announcement)
|
||||
|
||||
expected_id = "canvas-announcement-1-30"
|
||||
expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
|
||||
|
||||
assert doc.id == expected_id
|
||||
assert doc.source == DocumentSource.CANVAS
|
||||
assert doc.semantic_identifier == "Class Cancelled"
|
||||
assert doc.doc_updated_at == expected_updated_at
|
||||
|
||||
def test_convert_announcement_without_posted_at(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasAnnouncement
|
||||
|
||||
announcement = CanvasAnnouncement(
|
||||
id=31,
|
||||
title="TBD Announcement",
|
||||
message=None,
|
||||
html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/31",
|
||||
posted_at=None,
|
||||
course_id=1,
|
||||
)
|
||||
|
||||
doc = self.connector._convert_announcement_to_document(announcement)
|
||||
|
||||
assert doc.doc_updated_at is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasConnector — validate_connector_settings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateConnectorSettings:
|
||||
def _assert_validate_raises(
|
||||
self,
|
||||
status_code: int,
|
||||
expected_error: type[Exception],
|
||||
mock_requests: MagicMock,
|
||||
) -> None:
|
||||
"""Helper: assert validate_connector_settings raises expected_error."""
|
||||
success_resp = _mock_response(json_data=[_mock_course()])
|
||||
fail_resp = _mock_response(status_code, {})
|
||||
mock_requests.get.side_effect = [success_resp, fail_resp]
|
||||
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
|
||||
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
|
||||
with pytest.raises(expected_error):
|
||||
connector.validate_connector_settings()
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_success(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
|
||||
connector = _build_connector()
|
||||
|
||||
connector.validate_connector_settings() # should not raise
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_expired_credential(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_validate_raises(401, CredentialExpiredError, mock_requests)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_insufficient_permissions(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_validate_raises(403, InsufficientPermissionsError, mock_requests)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_rate_limited(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_validate_raises(429, ConnectorValidationError, mock_requests)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_unexpected_error(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_validate_raises(500, UnexpectedValidationError, mock_requests)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _list_* pagination tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListCourses:
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_single_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[_mock_course(1), _mock_course(2, "CS201", "Data Structures")]
|
||||
)
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_courses()
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].id == 1
|
||||
assert result[1].id == 2
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_empty_response(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_courses()
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestListPages:
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_single_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[_mock_page(10), _mock_page(11, "Notes")]
|
||||
)
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_pages(course_id=1)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].page_id == 10
|
||||
assert result[1].page_id == 11
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_empty_response(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_pages(course_id=1)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestListAssignments:
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_single_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[_mock_assignment(20), _mock_assignment(21, "Quiz 1")]
|
||||
)
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_assignments(course_id=1)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].id == 20
|
||||
assert result[1].id == 21
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_empty_response(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_assignments(course_id=1)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestListAnnouncements:
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_single_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[_mock_announcement(30), _mock_announcement(31, "Update")]
|
||||
)
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_announcements(course_id=1)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].id == 30
|
||||
assert result[1].id == 31
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_empty_response(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_announcements(course_id=1)
|
||||
|
||||
assert result == []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for memory tool streaming packet emissions."""
|
||||
|
||||
from queue import Queue
|
||||
import queue
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -18,9 +18,13 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter() -> Emitter:
|
||||
bus: Queue = Queue()
|
||||
return Emitter(bus)
|
||||
def emitter_queue() -> queue.Queue:
|
||||
return queue.Queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter(emitter_queue: queue.Queue) -> Emitter:
|
||||
return Emitter(merged_queue=emitter_queue)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -53,24 +57,27 @@ class TestMemoryToolEmitStart:
|
||||
def test_emit_start_emits_memory_tool_start_packet(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
) -> None:
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolStart)
|
||||
assert packet.placement == placement
|
||||
assert packet.placement is not None
|
||||
assert packet.placement.turn_index == placement.turn_index
|
||||
assert packet.placement.tab_index == placement.tab_index
|
||||
assert packet.placement.model_index == 0 # emitter stamps model_index=0
|
||||
|
||||
def test_emit_start_with_different_placement(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
) -> None:
|
||||
placement = Placement(turn_index=2, tab_index=1)
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert packet.placement.turn_index == 2
|
||||
assert packet.placement.tab_index == 1
|
||||
|
||||
@@ -81,7 +88,7 @@ class TestMemoryToolRun:
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -93,21 +100,19 @@ class TestMemoryToolRun:
|
||||
memory="User prefers Python",
|
||||
)
|
||||
|
||||
# The delta packet should be in the queue
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers Python"
|
||||
assert packet.obj.operation == "add"
|
||||
assert packet.obj.memory_id is None
|
||||
assert packet.obj.index is None
|
||||
assert packet.placement == placement
|
||||
|
||||
@patch("onyx.tools.tool_implementations.memory.memory_tool.process_memory_update")
|
||||
def test_run_emits_delta_for_update_operation(
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter: Emitter,
|
||||
emitter_queue: queue.Queue,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -119,7 +124,7 @@ class TestMemoryToolRun:
|
||||
memory="User prefers light mode",
|
||||
)
|
||||
|
||||
packet = emitter.bus.get_nowait()
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers light mode"
|
||||
assert packet.obj.operation == "update"
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.39
|
||||
version: 0.4.38
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,77 +0,0 @@
|
||||
{{- if and .Values.monitoring.serviceMonitors.enabled .Values.vectorDB.enabled }}
|
||||
{{- if gt (int .Values.celery_worker_monitoring.replicaCount) 0 }}
|
||||
---
|
||||
apiVersion: monitoring.coreos.com/v1
|
||||
kind: ServiceMonitor
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-monitoring
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.monitoring.serviceMonitors.labels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
namespaceSelector:
|
||||
matchNames:
|
||||
- {{ .Release.Namespace }}
|
||||
selector:
|
||||
matchLabels:
|
||||
app: {{ .Values.celery_worker_monitoring.deploymentLabels.app }}
|
||||
metrics: "true"
|
||||
endpoints:
|
||||
- port: metrics
|
||||
path: /metrics
|
||||
interval: 30s
|
||||
scrapeTimeout: 10s
|
||||
{{- end }}
|
||||
{{- if gt (int .Values.celery_worker_docfetching.replicaCount) 0 }}
|
||||
---
|
||||
apiVersion: monitoring.coreos.com/v1
|
||||
kind: ServiceMonitor
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-docfetching
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.monitoring.serviceMonitors.labels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
namespaceSelector:
|
||||
matchNames:
|
||||
- {{ .Release.Namespace }}
|
||||
selector:
|
||||
matchLabels:
|
||||
app: {{ .Values.celery_worker_docfetching.deploymentLabels.app }}
|
||||
metrics: "true"
|
||||
endpoints:
|
||||
- port: metrics
|
||||
path: /metrics
|
||||
interval: 30s
|
||||
scrapeTimeout: 10s
|
||||
{{- end }}
|
||||
{{- if gt (int .Values.celery_worker_docprocessing.replicaCount) 0 }}
|
||||
---
|
||||
apiVersion: monitoring.coreos.com/v1
|
||||
kind: ServiceMonitor
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-docprocessing
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.monitoring.serviceMonitors.labels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
namespaceSelector:
|
||||
matchNames:
|
||||
- {{ .Release.Namespace }}
|
||||
selector:
|
||||
matchLabels:
|
||||
app: {{ .Values.celery_worker_docprocessing.deploymentLabels.app }}
|
||||
metrics: "true"
|
||||
endpoints:
|
||||
- port: metrics
|
||||
path: /metrics
|
||||
interval: 30s
|
||||
scrapeTimeout: 10s
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -1,15 +0,0 @@
|
||||
{{- if .Values.monitoring.grafana.dashboards.enabled }}
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-indexing-pipeline-dashboard
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
grafana_dashboard: "1"
|
||||
annotations:
|
||||
grafana_folder: "Onyx"
|
||||
data:
|
||||
onyx-indexing-pipeline.json: |
|
||||
{{- .Files.Get "dashboards/indexing-pipeline.json" | nindent 4 }}
|
||||
{{- end }}
|
||||
@@ -256,20 +256,6 @@ tooling:
|
||||
# -- Which client binary to call; change if your image uses a non-default path.
|
||||
psqlBinary: psql
|
||||
|
||||
monitoring:
|
||||
grafana:
|
||||
dashboards:
|
||||
# -- Set to true to deploy Grafana dashboard ConfigMaps for the Onyx indexing pipeline.
|
||||
# Requires kube-prometheus-stack (or equivalent) with the Grafana sidecar enabled and watching this namespace.
|
||||
# The sidecar must be configured with label selector: grafana_dashboard=1
|
||||
enabled: false
|
||||
serviceMonitors:
|
||||
# -- Set to true to deploy ServiceMonitor resources for Celery worker metrics endpoints.
|
||||
# Requires the Prometheus Operator CRDs (included in kube-prometheus-stack).
|
||||
# Use `labels` to match your Prometheus CR's serviceMonitorSelector (e.g. release: onyx-monitoring).
|
||||
enabled: false
|
||||
labels: {}
|
||||
|
||||
serviceAccount:
|
||||
# Specifies whether a service account should be created
|
||||
create: false
|
||||
|
||||
@@ -19,10 +19,6 @@ module "eks" {
|
||||
cluster_endpoint_public_access_cidrs = var.cluster_endpoint_public_access_cidrs
|
||||
enable_cluster_creator_admin_permissions = true
|
||||
|
||||
# Control plane logging
|
||||
cluster_enabled_log_types = var.cluster_enabled_log_types
|
||||
cloudwatch_log_group_retention_in_days = var.cloudwatch_log_group_retention_in_days
|
||||
|
||||
eks_managed_node_group_defaults = {
|
||||
ami_type = "AL2023_x86_64_STANDARD"
|
||||
}
|
||||
|
||||
@@ -161,25 +161,3 @@ variable "rds_db_connect_arn" {
|
||||
description = "Full rds-db:connect ARN to allow (required when enable_rds_iam_for_service_account is true)"
|
||||
default = null
|
||||
}
|
||||
|
||||
variable "cluster_enabled_log_types" {
|
||||
type = list(string)
|
||||
description = "EKS control plane log types to enable (valid: api, audit, authenticator, controllerManager, scheduler)"
|
||||
default = ["api", "audit", "authenticator", "controllerManager", "scheduler"]
|
||||
|
||||
validation {
|
||||
condition = alltrue([for t in var.cluster_enabled_log_types : contains(["api", "audit", "authenticator", "controllerManager", "scheduler"], t)])
|
||||
error_message = "Each entry must be one of: api, audit, authenticator, controllerManager, scheduler."
|
||||
}
|
||||
}
|
||||
|
||||
variable "cloudwatch_log_group_retention_in_days" {
|
||||
type = number
|
||||
description = "Number of days to retain EKS control plane logs in CloudWatch (0 = never expire)"
|
||||
default = 30
|
||||
|
||||
validation {
|
||||
condition = contains([0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653], var.cloudwatch_log_group_retention_in_days)
|
||||
error_message = "Must be a valid CloudWatch retention value (0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653)."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,9 +54,6 @@ module "postgres" {
|
||||
password = var.postgres_password
|
||||
tags = local.merged_tags
|
||||
enable_rds_iam_auth = var.enable_iam_auth
|
||||
|
||||
backup_retention_period = var.postgres_backup_retention_period
|
||||
backup_window = var.postgres_backup_window
|
||||
}
|
||||
|
||||
module "s3" {
|
||||
@@ -83,10 +80,6 @@ module "eks" {
|
||||
public_cluster_enabled = var.public_cluster_enabled
|
||||
private_cluster_enabled = var.private_cluster_enabled
|
||||
cluster_endpoint_public_access_cidrs = var.cluster_endpoint_public_access_cidrs
|
||||
|
||||
# Control plane logging
|
||||
cluster_enabled_log_types = var.eks_cluster_enabled_log_types
|
||||
cloudwatch_log_group_retention_in_days = var.eks_cloudwatch_log_group_retention_in_days
|
||||
}
|
||||
|
||||
module "waf" {
|
||||
|
||||
@@ -250,34 +250,3 @@ variable "opensearch_subnet_ids" {
|
||||
description = "Subnet IDs for OpenSearch. If empty, uses first 3 private subnets."
|
||||
default = []
|
||||
}
|
||||
|
||||
# RDS Backup Configuration
|
||||
variable "postgres_backup_retention_period" {
|
||||
type = number
|
||||
description = "Number of days to retain automated RDS backups (0 to disable)"
|
||||
default = 7
|
||||
}
|
||||
|
||||
variable "postgres_backup_window" {
|
||||
type = string
|
||||
description = "Preferred UTC time window for automated RDS backups (hh24:mi-hh24:mi)"
|
||||
default = "03:00-04:00"
|
||||
}
|
||||
|
||||
# EKS Control Plane Logging
|
||||
variable "eks_cluster_enabled_log_types" {
|
||||
type = list(string)
|
||||
description = "EKS control plane log types to enable (valid: api, audit, authenticator, controllerManager, scheduler)"
|
||||
default = ["api", "audit", "authenticator", "controllerManager", "scheduler"]
|
||||
}
|
||||
|
||||
variable "eks_cloudwatch_log_group_retention_in_days" {
|
||||
type = number
|
||||
description = "Number of days to retain EKS control plane logs in CloudWatch (0 = never expire)"
|
||||
default = 30
|
||||
|
||||
validation {
|
||||
condition = contains([0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653], var.eks_cloudwatch_log_group_retention_in_days)
|
||||
error_message = "Must be a valid CloudWatch retention value (0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653)."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,56 +44,5 @@ resource "aws_db_instance" "this" {
|
||||
publicly_accessible = false
|
||||
deletion_protection = true
|
||||
storage_encrypted = true
|
||||
|
||||
# Automated backups
|
||||
backup_retention_period = var.backup_retention_period
|
||||
backup_window = var.backup_window
|
||||
|
||||
tags = var.tags
|
||||
}
|
||||
|
||||
# CloudWatch alarm for CPU utilization monitoring
|
||||
resource "aws_cloudwatch_metric_alarm" "cpu_utilization" {
|
||||
alarm_name = "${var.identifier}-cpu-utilization"
|
||||
alarm_description = "RDS CPU utilization for ${var.identifier}"
|
||||
comparison_operator = "GreaterThanThreshold"
|
||||
evaluation_periods = var.cpu_alarm_evaluation_periods
|
||||
metric_name = "CPUUtilization"
|
||||
namespace = "AWS/RDS"
|
||||
period = var.cpu_alarm_period
|
||||
statistic = "Average"
|
||||
threshold = var.cpu_alarm_threshold
|
||||
treat_missing_data = "missing"
|
||||
|
||||
alarm_actions = var.alarm_actions
|
||||
ok_actions = var.alarm_actions
|
||||
|
||||
dimensions = {
|
||||
DBInstanceIdentifier = aws_db_instance.this.identifier
|
||||
}
|
||||
|
||||
tags = var.tags
|
||||
}
|
||||
|
||||
# CloudWatch alarm for freeable memory monitoring
|
||||
resource "aws_cloudwatch_metric_alarm" "freeable_memory" {
|
||||
alarm_name = "${var.identifier}-freeable-memory"
|
||||
alarm_description = "RDS freeable memory for ${var.identifier}"
|
||||
comparison_operator = "LessThanThreshold"
|
||||
evaluation_periods = var.memory_alarm_evaluation_periods
|
||||
metric_name = "FreeableMemory"
|
||||
namespace = "AWS/RDS"
|
||||
period = var.memory_alarm_period
|
||||
statistic = "Average"
|
||||
threshold = var.memory_alarm_threshold
|
||||
treat_missing_data = "missing"
|
||||
|
||||
alarm_actions = var.alarm_actions
|
||||
ok_actions = var.alarm_actions
|
||||
|
||||
dimensions = {
|
||||
DBInstanceIdentifier = aws_db_instance.this.identifier
|
||||
}
|
||||
|
||||
tags = var.tags
|
||||
tags = var.tags
|
||||
}
|
||||
|
||||
@@ -67,98 +67,3 @@ variable "enable_rds_iam_auth" {
|
||||
description = "Enable AWS IAM database authentication for this RDS instance"
|
||||
default = false
|
||||
}
|
||||
|
||||
variable "backup_retention_period" {
|
||||
type = number
|
||||
description = "Number of days to retain automated backups (0 to disable)"
|
||||
default = 7
|
||||
|
||||
validation {
|
||||
condition = var.backup_retention_period >= 0 && var.backup_retention_period <= 35
|
||||
error_message = "backup_retention_period must be between 0 and 35 (AWS RDS limit)."
|
||||
}
|
||||
}
|
||||
|
||||
variable "backup_window" {
|
||||
type = string
|
||||
description = "Preferred UTC time window for automated backups (hh24:mi-hh24:mi)"
|
||||
default = "03:00-04:00"
|
||||
|
||||
validation {
|
||||
condition = can(regex("^([01]\\d|2[0-3]):[0-5]\\d-([01]\\d|2[0-3]):[0-5]\\d$", var.backup_window))
|
||||
error_message = "backup_window must be in hh24:mi-hh24:mi format (e.g. \"03:00-04:00\")."
|
||||
}
|
||||
}
|
||||
|
||||
# CloudWatch CPU alarm configuration
|
||||
variable "cpu_alarm_threshold" {
|
||||
type = number
|
||||
description = "CPU utilization percentage threshold for the CloudWatch alarm"
|
||||
default = 80
|
||||
|
||||
validation {
|
||||
condition = var.cpu_alarm_threshold >= 0 && var.cpu_alarm_threshold <= 100
|
||||
error_message = "cpu_alarm_threshold must be between 0 and 100 (percentage)."
|
||||
}
|
||||
}
|
||||
|
||||
variable "cpu_alarm_evaluation_periods" {
|
||||
type = number
|
||||
description = "Number of consecutive periods the threshold must be breached before alarming"
|
||||
default = 3
|
||||
|
||||
validation {
|
||||
condition = var.cpu_alarm_evaluation_periods >= 1
|
||||
error_message = "cpu_alarm_evaluation_periods must be at least 1."
|
||||
}
|
||||
}
|
||||
|
||||
variable "cpu_alarm_period" {
|
||||
type = number
|
||||
description = "Period in seconds over which the CPU metric is evaluated"
|
||||
default = 300
|
||||
|
||||
validation {
|
||||
condition = var.cpu_alarm_period >= 60 && var.cpu_alarm_period % 60 == 0
|
||||
error_message = "cpu_alarm_period must be a multiple of 60 seconds and at least 60 (CloudWatch requirement)."
|
||||
}
|
||||
}
|
||||
|
||||
variable "memory_alarm_threshold" {
|
||||
type = number
|
||||
description = "Freeable memory threshold in bytes. Alarm fires when memory drops below this value."
|
||||
default = 256000000 # 256 MB
|
||||
|
||||
validation {
|
||||
condition = var.memory_alarm_threshold > 0
|
||||
error_message = "memory_alarm_threshold must be greater than 0."
|
||||
}
|
||||
}
|
||||
|
||||
variable "memory_alarm_evaluation_periods" {
|
||||
type = number
|
||||
description = "Number of consecutive periods the threshold must be breached before alarming"
|
||||
default = 3
|
||||
|
||||
validation {
|
||||
condition = var.memory_alarm_evaluation_periods >= 1
|
||||
error_message = "memory_alarm_evaluation_periods must be at least 1."
|
||||
}
|
||||
}
|
||||
|
||||
variable "memory_alarm_period" {
|
||||
type = number
|
||||
description = "Period in seconds over which the freeable memory metric is evaluated"
|
||||
default = 300
|
||||
|
||||
validation {
|
||||
condition = var.memory_alarm_period >= 60 && var.memory_alarm_period % 60 == 0
|
||||
error_message = "memory_alarm_period must be a multiple of 60 seconds and at least 60 (CloudWatch requirement)."
|
||||
}
|
||||
}
|
||||
|
||||
variable "alarm_actions" {
|
||||
type = list(string)
|
||||
description = "List of ARNs to notify when the alarm transitions state (e.g. SNS topic ARNs)"
|
||||
default = []
|
||||
}
|
||||
|
||||
@@ -1,349 +0,0 @@
|
||||
{
|
||||
"annotations": {
|
||||
"list": [
|
||||
{
|
||||
"builtIn": 1,
|
||||
"datasource": { "type": "grafana", "uid": "-- Grafana --" },
|
||||
"enable": true,
|
||||
"hide": true,
|
||||
"iconColor": "rgba(0, 211, 255, 1)",
|
||||
"name": "Annotations & Alerts",
|
||||
"type": "dashboard"
|
||||
}
|
||||
]
|
||||
},
|
||||
"editable": true,
|
||||
"fiscalYearStartMonth": 0,
|
||||
"graphTooltip": 1,
|
||||
"id": null,
|
||||
"links": [],
|
||||
"liveNow": true,
|
||||
"panels": [
|
||||
{
|
||||
"title": "Client-Side Search Latency (P50 / P95 / P99)",
|
||||
"description": "End-to-end latency as measured by the Python client, including network round-trip and serialization overhead.",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 0 },
|
||||
"id": 1,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "seconds",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "none" },
|
||||
"thresholdsStyle": { "mode": "dashed" }
|
||||
},
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{ "color": "green", "value": null },
|
||||
{ "color": "yellow", "value": 0.5 },
|
||||
{ "color": "red", "value": 2.0 }
|
||||
]
|
||||
},
|
||||
"unit": "s",
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P50",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.95, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P95",
|
||||
"refId": "B"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.99, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P99",
|
||||
"refId": "C"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Server-Side Search Latency (P50 / P95 / P99)",
|
||||
"description": "OpenSearch server-side execution time from the 'took' field in the response. Does not include network or client-side overhead.",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 0 },
|
||||
"id": 2,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "seconds",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "none" },
|
||||
"thresholdsStyle": { "mode": "dashed" }
|
||||
},
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{ "color": "green", "value": null },
|
||||
{ "color": "yellow", "value": 0.5 },
|
||||
{ "color": "red", "value": 2.0 }
|
||||
]
|
||||
},
|
||||
"unit": "s",
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P50",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.95, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P95",
|
||||
"refId": "B"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.99, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P99",
|
||||
"refId": "C"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Client-Side Latency by Search Type (P95)",
|
||||
"description": "P95 client-side latency broken down by search type (hybrid, keyword, semantic, random, doc_id_retrieval).",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 10 },
|
||||
"id": 3,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "seconds",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "none" },
|
||||
"thresholdsStyle": { "mode": "off" }
|
||||
},
|
||||
"unit": "s",
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.95, sum by (search_type, le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "{{ search_type }}",
|
||||
"refId": "A"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Search Throughput by Type",
|
||||
"description": "Searches per second broken down by search type.",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 10 },
|
||||
"id": 4,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "searches/s",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "normal" },
|
||||
"thresholdsStyle": { "mode": "off" }
|
||||
},
|
||||
"unit": "ops",
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum by (search_type) (rate(onyx_opensearch_search_total[5m]))",
|
||||
"legendFormat": "{{ search_type }}",
|
||||
"refId": "A"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Concurrent Searches In Progress",
|
||||
"description": "Number of OpenSearch searches currently in flight, broken down by search type. Summed across all instances.",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 20 },
|
||||
"id": 5,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "searches",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "normal" },
|
||||
"thresholdsStyle": { "mode": "off" }
|
||||
},
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum by (search_type) (onyx_opensearch_searches_in_progress)",
|
||||
"legendFormat": "{{ search_type }}",
|
||||
"refId": "A"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Client vs Server Latency Overhead (P50)",
|
||||
"description": "Difference between client-side and server-side P50 latency. Reveals network, serialization, and untracked OpenSearch overhead.",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 20 },
|
||||
"id": 6,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "seconds",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "none" },
|
||||
"thresholdsStyle": { "mode": "off" }
|
||||
},
|
||||
"unit": "s",
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m]))) - histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "Client - Server overhead (P50)",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "Client P50",
|
||||
"refId": "B"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "Server P50",
|
||||
"refId": "C"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"refresh": "5s",
|
||||
"schemaVersion": 37,
|
||||
"style": "dark",
|
||||
"tags": ["onyx", "opensearch", "search", "latency"],
|
||||
"templating": {
|
||||
"list": [
|
||||
{
|
||||
"current": {
|
||||
"text": "Prometheus",
|
||||
"value": "prometheus"
|
||||
},
|
||||
"includeAll": false,
|
||||
"name": "DS_PROMETHEUS",
|
||||
"options": [],
|
||||
"query": "prometheus",
|
||||
"refresh": 1,
|
||||
"type": "datasource"
|
||||
}
|
||||
]
|
||||
},
|
||||
"time": { "from": "now-60m", "to": "now" },
|
||||
"timepicker": {
|
||||
"refresh_intervals": ["5s", "10s", "30s", "1m"]
|
||||
},
|
||||
"timezone": "",
|
||||
"title": "Onyx OpenSearch Search Latency",
|
||||
"uid": "onyx-opensearch-search-latency",
|
||||
"version": 0,
|
||||
"weekStart": ""
|
||||
}
|
||||
@@ -73,17 +73,11 @@ ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
|
||||
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
|
||||
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
|
||||
|
||||
ARG SENTRY_RELEASE
|
||||
ENV SENTRY_RELEASE=${SENTRY_RELEASE}
|
||||
|
||||
# Add NODE_OPTIONS argument
|
||||
ARG NODE_OPTIONS
|
||||
|
||||
# SENTRY_AUTH_TOKEN is injected via BuildKit secret mount so it is never written
|
||||
# to any image layer, build cache, or registry manifest.
|
||||
# Use NODE_OPTIONS in the build command
|
||||
RUN --mount=type=secret,id=sentry_auth_token,env=SENTRY_AUTH_TOKEN \
|
||||
NODE_OPTIONS="${NODE_OPTIONS}" npx next build
|
||||
RUN NODE_OPTIONS="${NODE_OPTIONS}" npx next build
|
||||
|
||||
# Step 2. Production image, copy all the files and run next
|
||||
FROM base AS runner
|
||||
@@ -156,9 +150,6 @@ ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
|
||||
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
|
||||
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
|
||||
|
||||
ARG SENTRY_RELEASE
|
||||
ENV SENTRY_RELEASE=${SENTRY_RELEASE}
|
||||
|
||||
# Default ONYX_VERSION, typically overridden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION}
|
||||
|
||||
@@ -24,7 +24,6 @@ type TextFont =
|
||||
| "secondary-body"
|
||||
| "secondary-action"
|
||||
| "secondary-mono"
|
||||
| "secondary-mono-label"
|
||||
| "figure-small-label"
|
||||
| "figure-small-value"
|
||||
| "figure-keystroke";
|
||||
@@ -89,7 +88,6 @@ const FONT_CONFIG: Record<TextFont, string> = {
|
||||
"secondary-body": "font-secondary-body",
|
||||
"secondary-action": "font-secondary-action",
|
||||
"secondary-mono": "font-secondary-mono",
|
||||
"secondary-mono-label": "font-secondary-mono-label",
|
||||
"figure-small-label": "font-figure-small-label",
|
||||
"figure-small-value": "font-figure-small-value",
|
||||
"figure-keystroke": "font-figure-keystroke",
|
||||
|
||||
@@ -8,7 +8,6 @@ import * as Sentry from "@sentry/nextjs";
|
||||
if (process.env.NEXT_PUBLIC_SENTRY_DSN) {
|
||||
Sentry.init({
|
||||
dsn: process.env.NEXT_PUBLIC_SENTRY_DSN,
|
||||
release: process.env.SENTRY_RELEASE,
|
||||
// Only capture unhandled exceptions
|
||||
tracesSampleRate: 0,
|
||||
debug: false,
|
||||
|
||||
@@ -7,7 +7,6 @@ import * as Sentry from "@sentry/nextjs";
|
||||
if (process.env.NEXT_PUBLIC_SENTRY_DSN) {
|
||||
Sentry.init({
|
||||
dsn: process.env.NEXT_PUBLIC_SENTRY_DSN,
|
||||
release: process.env.SENTRY_RELEASE,
|
||||
|
||||
// Setting this option to true will print useful information to the console while you're setting up Sentry.
|
||||
debug: false,
|
||||
|
||||
@@ -1 +1,7 @@
|
||||
export { default } from "@/refresh-pages/admin/CodeInterpreterPage";
|
||||
"use client";
|
||||
|
||||
import CodeInterpreterPage from "@/refresh-pages/admin/CodeInterpreterPage";
|
||||
|
||||
export default function Page() {
|
||||
return <CodeInterpreterPage />;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
"use client";
|
||||
|
||||
import { ModalCreationInterface } from "@/refresh-components/contexts/ModalContext";
|
||||
import { ImageProvider } from "@/app/admin/configuration/image-generation/constants";
|
||||
import { LLMProviderView } from "@/interfaces/llm";
|
||||
import { ImageGenerationConfigView } from "@/lib/configuration/imageConfigurationService";
|
||||
import { getImageGenForm } from "./forms";
|
||||
|
||||
interface Props {
|
||||
modal: ModalCreationInterface;
|
||||
imageProvider: ImageProvider;
|
||||
existingProviders: LLMProviderView[];
|
||||
existingConfig?: ImageGenerationConfigView;
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Modal for creating/editing image generation configurations.
|
||||
* Routes to provider-specific forms based on imageProvider.provider_name.
|
||||
*/
|
||||
export default function ImageGenerationConnectionModal(props: Props) {
|
||||
return <>{getImageGenForm(props)}</>;
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import { useState, useMemo, useEffect } from "react";
|
||||
import useSWR from "swr";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
@@ -10,39 +11,24 @@ import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
|
||||
import {
|
||||
IMAGE_PROVIDER_GROUPS,
|
||||
ImageProvider,
|
||||
} from "@/refresh-pages/admin/ImageGenerationPage/constants";
|
||||
} from "@/app/admin/configuration/image-generation/constants";
|
||||
import ImageGenerationConnectionModal from "@/app/admin/configuration/image-generation/ImageGenerationConnectionModal";
|
||||
import {
|
||||
ImageGenerationConfigView,
|
||||
setDefaultImageGenerationConfig,
|
||||
unsetDefaultImageGenerationConfig,
|
||||
deleteImageGenerationConfig,
|
||||
} from "@/refresh-pages/admin/ImageGenerationPage/svc";
|
||||
} from "@/lib/configuration/imageConfigurationService";
|
||||
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { Button, SelectCard, Text } from "@opal/components";
|
||||
import { Content, CardHeaderLayout } from "@opal/layouts";
|
||||
import { Hoverable } from "@opal/core";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgArrowRightCircle,
|
||||
SvgCheckSquare,
|
||||
SvgSettings,
|
||||
SvgSlash,
|
||||
SvgUnplug,
|
||||
} from "@opal/icons";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { getImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms";
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
const STATUS_TO_STATE = {
|
||||
disconnected: "empty",
|
||||
connected: "filled",
|
||||
selected: "selected",
|
||||
} as const;
|
||||
|
||||
export default function ImageGenerationContent() {
|
||||
const {
|
||||
data: llmProviderResponse,
|
||||
@@ -212,13 +198,16 @@ export default function ImageGenerationContent() {
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Content
|
||||
title="Image Generation Model"
|
||||
description="Select a model to generate images in chat."
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
/>
|
||||
<div className="flex flex-col gap-6">
|
||||
{/* Section Header */}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text font="main-content-emphasis" color="text-05">
|
||||
Image Generation Model
|
||||
</Text>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
Select a model to generate images in chat.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{connectedProviderIds.size === 0 && (
|
||||
<Message
|
||||
@@ -234,111 +223,32 @@ export default function ImageGenerationContent() {
|
||||
{/* Provider Groups */}
|
||||
{IMAGE_PROVIDER_GROUPS.map((group) => (
|
||||
<div key={group.name} className="flex flex-col gap-2">
|
||||
<Content title={group.name} sizePreset="secondary" variant="body" />
|
||||
{group.providers.map((provider) => {
|
||||
const status = getStatus(provider);
|
||||
const isDisconnected = status === "disconnected";
|
||||
const isConnected = status === "connected";
|
||||
const isSelected = status === "selected";
|
||||
|
||||
return (
|
||||
<Hoverable.Root
|
||||
<Text font="secondary-body" color="text-03">
|
||||
{group.name}
|
||||
</Text>
|
||||
<div className="flex flex-col gap-2">
|
||||
{group.providers.map((provider) => (
|
||||
<Select
|
||||
key={provider.image_provider_id}
|
||||
group="image-gen/ProviderCard"
|
||||
>
|
||||
<SelectCard
|
||||
variant="select-card"
|
||||
state={STATUS_TO_STATE[status]}
|
||||
sizeVariant="lg"
|
||||
aria-label={`image-gen-provider-${provider.image_provider_id}`}
|
||||
onClick={
|
||||
isDisconnected
|
||||
? () => handleConnect(provider)
|
||||
: isSelected
|
||||
? () => handleDeselect(provider)
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
<CardHeaderLayout
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
icon={() => (
|
||||
<ProviderIcon
|
||||
provider={provider.provider_name}
|
||||
size={16}
|
||||
/>
|
||||
)}
|
||||
title={provider.title}
|
||||
description={provider.description}
|
||||
rightChildren={
|
||||
isDisconnected ? (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgArrowExchange}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleConnect(provider);
|
||||
}}
|
||||
>
|
||||
Connect
|
||||
</Button>
|
||||
) : isConnected ? (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgArrowRightCircle}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleSelect(provider);
|
||||
}}
|
||||
>
|
||||
Set as Default
|
||||
</Button>
|
||||
) : isSelected ? (
|
||||
<div className="p-2">
|
||||
<Content
|
||||
title="Current Default"
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
icon={SvgCheckSquare}
|
||||
/>
|
||||
</div>
|
||||
) : undefined
|
||||
}
|
||||
bottomRightChildren={
|
||||
!isDisconnected ? (
|
||||
<div className="flex flex-row px-1 pb-1">
|
||||
<Hoverable.Item group="image-gen/ProviderCard">
|
||||
<Button
|
||||
icon={SvgUnplug}
|
||||
tooltip="Disconnect"
|
||||
aria-label={`Disconnect ${provider.title}`}
|
||||
prominence="tertiary"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setDisconnectProvider(provider);
|
||||
}}
|
||||
size="md"
|
||||
/>
|
||||
</Hoverable.Item>
|
||||
<Button
|
||||
icon={SvgSettings}
|
||||
tooltip="Edit"
|
||||
aria-label={`Edit ${provider.title}`}
|
||||
prominence="tertiary"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleEdit(provider);
|
||||
}}
|
||||
size="md"
|
||||
/>
|
||||
</div>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</SelectCard>
|
||||
</Hoverable.Root>
|
||||
);
|
||||
})}
|
||||
aria-label={`image-gen-provider-${provider.image_provider_id}`}
|
||||
icon={() => (
|
||||
<ProviderIcon provider={provider.provider_name} size={18} />
|
||||
)}
|
||||
title={provider.title}
|
||||
description={provider.description}
|
||||
status={getStatus(provider)}
|
||||
onConnect={() => handleConnect(provider)}
|
||||
onSelect={() => handleSelect(provider)}
|
||||
onDeselect={() => handleDeselect(provider)}
|
||||
onEdit={() => handleEdit(provider)}
|
||||
onDisconnect={
|
||||
getStatus(provider) !== "disconnected"
|
||||
? () => setDisconnectProvider(provider)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
@@ -447,13 +357,13 @@ export default function ImageGenerationContent() {
|
||||
|
||||
{activeProvider && (
|
||||
<modal.Provider>
|
||||
{getImageGenForm({
|
||||
modal: modal,
|
||||
imageProvider: activeProvider,
|
||||
existingProviders: llmProviders,
|
||||
existingConfig: editConfig || undefined,
|
||||
onSuccess: handleModalSuccess,
|
||||
})}
|
||||
<ImageGenerationConnectionModal
|
||||
modal={modal}
|
||||
imageProvider={activeProvider}
|
||||
existingProviders={llmProviders}
|
||||
existingConfig={editConfig || undefined}
|
||||
onSuccess={handleModalSuccess}
|
||||
/>
|
||||
</modal.Provider>
|
||||
)}
|
||||
</>
|
||||
@@ -7,14 +7,14 @@ import { FormField } from "@/refresh-components/form/FormField";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
|
||||
import { ImageGenFormWrapper } from "@/refresh-pages/admin/ImageGenerationPage/forms/ImageGenFormWrapper";
|
||||
import { ImageGenFormWrapper } from "./ImageGenFormWrapper";
|
||||
import {
|
||||
ImageGenFormBaseProps,
|
||||
ImageGenFormChildProps,
|
||||
ImageGenSubmitPayload,
|
||||
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
|
||||
import { ImageGenerationCredentials } from "@/refresh-pages/admin/ImageGenerationPage/svc";
|
||||
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
|
||||
} from "./types";
|
||||
import { ImageGenerationCredentials } from "@/lib/configuration/imageConfigurationService";
|
||||
import { ImageProvider } from "../constants";
|
||||
import {
|
||||
parseAzureTargetUri,
|
||||
isValidAzureTargetUri,
|
||||
@@ -10,14 +10,14 @@ import {
|
||||
createImageGenerationConfig,
|
||||
updateImageGenerationConfig,
|
||||
fetchImageGenerationCredentials,
|
||||
} from "@/refresh-pages/admin/ImageGenerationPage/svc";
|
||||
} from "@/lib/configuration/imageConfigurationService";
|
||||
import { APIFormFieldState } from "@/refresh-components/form/types";
|
||||
import {
|
||||
ImageGenFormWrapperProps,
|
||||
ImageGenFormChildProps,
|
||||
ImageGenSubmitPayload,
|
||||
FormValues,
|
||||
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
|
||||
} from "./types";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
export function ImageGenFormWrapper<T extends FormValues>({
|
||||
@@ -41,6 +41,9 @@ export function ImageGenFormWrapper<T extends FormValues>({
|
||||
const [errorMessage, setErrorMessage] = useState("");
|
||||
const [isLoadingCredentials, setIsLoadingCredentials] = useState(false);
|
||||
|
||||
// Form reset key for re-initialization
|
||||
const [formResetKey, setFormResetKey] = useState(0);
|
||||
|
||||
// Track merged initial values with fetched credentials
|
||||
const [mergedInitialValues, setMergedInitialValues] =
|
||||
useState<T>(initialValues);
|
||||
@@ -70,6 +73,7 @@ export function ImageGenFormWrapper<T extends FormValues>({
|
||||
imageProvider
|
||||
);
|
||||
setMergedInitialValues((prev) => ({ ...prev, ...credValues }));
|
||||
setFormResetKey((k) => k + 1);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
@@ -272,6 +276,7 @@ export function ImageGenFormWrapper<T extends FormValues>({
|
||||
|
||||
return (
|
||||
<Formik<T>
|
||||
key={formResetKey}
|
||||
initialValues={mergedInitialValues}
|
||||
onSubmit={handleSubmit}
|
||||
validationSchema={validationSchema}
|
||||
@@ -6,14 +6,14 @@ import { FormikField } from "@/refresh-components/form/FormikField";
|
||||
import { FormField } from "@/refresh-components/form/FormField";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
|
||||
import { ImageGenFormWrapper } from "@/refresh-pages/admin/ImageGenerationPage/forms/ImageGenFormWrapper";
|
||||
import { ImageGenFormWrapper } from "./ImageGenFormWrapper";
|
||||
import {
|
||||
ImageGenFormBaseProps,
|
||||
ImageGenFormChildProps,
|
||||
ImageGenSubmitPayload,
|
||||
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
|
||||
import { ImageGenerationCredentials } from "@/refresh-pages/admin/ImageGenerationPage/svc";
|
||||
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
|
||||
} from "./types";
|
||||
import { ImageGenerationCredentials } from "@/lib/configuration/imageConfigurationService";
|
||||
import { ImageProvider } from "../constants";
|
||||
|
||||
// OpenAI form values - just API key
|
||||
interface OpenAIFormValues {
|
||||
@@ -6,14 +6,14 @@ import { FormField } from "@/refresh-components/form/FormField";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputFile from "@/refresh-components/inputs/InputFile";
|
||||
import InlineExternalLink from "@/refresh-components/InlineExternalLink";
|
||||
import { ImageGenFormWrapper } from "@/refresh-pages/admin/ImageGenerationPage/forms/ImageGenFormWrapper";
|
||||
import { ImageGenFormWrapper } from "./ImageGenFormWrapper";
|
||||
import {
|
||||
ImageGenFormBaseProps,
|
||||
ImageGenFormChildProps,
|
||||
ImageGenSubmitPayload,
|
||||
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
|
||||
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
|
||||
import { ImageGenerationCredentials } from "@/refresh-pages/admin/ImageGenerationPage/svc";
|
||||
} from "./types";
|
||||
import { ImageProvider } from "../constants";
|
||||
import { ImageGenerationCredentials } from "@/lib/configuration/imageConfigurationService";
|
||||
|
||||
const VERTEXAI_PROVIDER_NAME = "vertex_ai";
|
||||
const VERTEXAI_DEFAULT_LOCATION = "global";
|
||||
@@ -1,8 +1,8 @@
|
||||
import React from "react";
|
||||
import { ImageGenFormBaseProps } from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
|
||||
import { OpenAIImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/OpenAIImageGenForm";
|
||||
import { AzureImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/AzureImageGenForm";
|
||||
import { VertexImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/VertexImageGenForm";
|
||||
import { ImageGenFormBaseProps } from "./types";
|
||||
import { OpenAIImageGenForm } from "./OpenAIImageGenForm";
|
||||
import { AzureImageGenForm } from "./AzureImageGenForm";
|
||||
import { VertexImageGenForm } from "./VertexImageGenForm";
|
||||
|
||||
/**
|
||||
* Factory function that routes to the correct provider-specific form
|
||||
@@ -0,0 +1,5 @@
|
||||
export * from "./types";
|
||||
export { ImageGenFormWrapper } from "./ImageGenFormWrapper";
|
||||
export { OpenAIImageGenForm } from "./OpenAIImageGenForm";
|
||||
export { AzureImageGenForm } from "./AzureImageGenForm";
|
||||
export { getImageGenForm } from "./getImageGenForm";
|
||||
@@ -1,10 +1,10 @@
|
||||
import { FormikProps } from "formik";
|
||||
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
|
||||
import { ImageProvider } from "../constants";
|
||||
import { LLMProviderView } from "@/interfaces/llm";
|
||||
import {
|
||||
ImageGenerationConfigView,
|
||||
ImageGenerationCredentials,
|
||||
} from "@/refresh-pages/admin/ImageGenerationPage/svc";
|
||||
} from "@/lib/configuration/imageConfigurationService";
|
||||
import { ModalCreationInterface } from "@/refresh-components/contexts/ModalContext";
|
||||
import { APIFormFieldState } from "@/refresh-components/form/types";
|
||||
|
||||
@@ -1 +1,22 @@
|
||||
export { default } from "@/refresh-pages/admin/ImageGenerationPage";
|
||||
"use client";
|
||||
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import ImageGenerationContent from "./ImageGenerationContent";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTES.IMAGE_GENERATION;
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
description="Settings for in-chat image generation."
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<ImageGenerationContent />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1 +1,7 @@
|
||||
export { default } from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
"use client";
|
||||
|
||||
import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
|
||||
export default function Page() {
|
||||
return <LLMConfigurationPage />;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { Form, Formik } from "formik";
|
||||
import { mutate } from "swr";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import * as Yup from "yup";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import {
|
||||
@@ -121,10 +119,6 @@ export const DocumentSetCreationForm = ({
|
||||
? "Successfully updated document set!"
|
||||
: "Successfully created document set!"
|
||||
);
|
||||
await Promise.all([
|
||||
mutate(SWR_KEYS.documentSets),
|
||||
mutate(SWR_KEYS.documentSetsEditable),
|
||||
]);
|
||||
onClose();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { DocumentSetSummary } from "@/lib/types";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
const DOCUMENT_SETS_URL = "/api/manage/document-set";
|
||||
const GET_EDITABLE_DOCUMENT_SETS_URL =
|
||||
"/api/manage/document-set?get_editable=true";
|
||||
|
||||
export function refreshDocumentSets() {
|
||||
mutate(SWR_KEYS.documentSets);
|
||||
mutate(DOCUMENT_SETS_URL);
|
||||
}
|
||||
|
||||
export function useDocumentSets(getEditable: boolean = false) {
|
||||
const url = getEditable
|
||||
? SWR_KEYS.documentSetsEditable
|
||||
: SWR_KEYS.documentSets;
|
||||
const url = getEditable ? GET_EDITABLE_DOCUMENT_SETS_URL : DOCUMENT_SETS_URL;
|
||||
|
||||
const swrResponse = useSWR<DocumentSetSummary[]>(url, errorHandlingFetcher, {
|
||||
refreshInterval: 5000, // 5 seconds
|
||||
|
||||
@@ -182,7 +182,8 @@ export async function* sendMessage({
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
const data = await response.json().catch(() => ({}));
|
||||
throw new Error(data.detail ?? `HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
yield* handleSSEStream<PacketType>(response, signal);
|
||||
|
||||
@@ -11,7 +11,6 @@ import Text from "@/refresh-components/texts/Text";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { ChatSessionMinimal } from "@/app/ee/admin/performance/usage/types";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { timestampToReadableDate } from "@/lib/dateUtils";
|
||||
import { Dispatch, SetStateAction, useCallback, useState } from "react";
|
||||
import { Feedback, TaskStatus } from "@/lib/types";
|
||||
@@ -102,32 +101,34 @@ function SelectFeedbackType({
|
||||
onValueChange: (value: Feedback | "all") => void;
|
||||
}) {
|
||||
return (
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" className="font-medium">
|
||||
<div>
|
||||
<Text as="p" className="my-auto mr-2 font-medium mb-1">
|
||||
Feedback Type
|
||||
</Text>
|
||||
<InputSelect
|
||||
value={value}
|
||||
onValueChange={onValueChange as (value: string) => void}
|
||||
>
|
||||
<InputSelect.Trigger />
|
||||
<div className="max-w-sm space-y-6">
|
||||
<InputSelect
|
||||
value={value}
|
||||
onValueChange={onValueChange as (value: string) => void}
|
||||
>
|
||||
<InputSelect.Trigger />
|
||||
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item value="all" icon={SvgMinusCircle}>
|
||||
Any
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="like" icon={SvgThumbsUp}>
|
||||
Like
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="dislike" icon={SvgThumbsDown}>
|
||||
Dislike
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="mixed" icon={SvgMinus}>
|
||||
Mixed
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item value="all" icon={SvgMinusCircle}>
|
||||
Any
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="like" icon={SvgThumbsUp}>
|
||||
Like
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="dislike" icon={SvgThumbsDown}>
|
||||
Dislike
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="mixed" icon={SvgMinus}>
|
||||
Mixed
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -184,61 +185,60 @@ function PreviousQueryHistoryExportsModal({
|
||||
onClose={() => setShowModal(false)}
|
||||
/>
|
||||
<Modal.Body>
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Generated At</TableHead>
|
||||
<TableHead>Start Range</TableHead>
|
||||
<TableHead>End Range</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Download</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{paginatedTasks.map((task, index) => (
|
||||
<TableRow key={index}>
|
||||
<TableCell>
|
||||
{humanReadableFormatWithTime(task.startTime)}
|
||||
</TableCell>
|
||||
<TableCell>{task.start.toDateString()}</TableCell>
|
||||
<TableCell>{task.end.toDateString()}</TableCell>
|
||||
<TableCell>
|
||||
<ExportBadge status={task.status} />
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Button
|
||||
variant="default"
|
||||
prominence="tertiary"
|
||||
icon={SvgDownloadCloud}
|
||||
size="sm"
|
||||
disabled={task.status !== "SUCCESS"}
|
||||
tooltip={
|
||||
task.status !== "SUCCESS"
|
||||
? "Export is not yet ready"
|
||||
: undefined
|
||||
}
|
||||
href={
|
||||
task.status === "SUCCESS"
|
||||
? withRequestId(
|
||||
<div className="flex flex-col w-full">
|
||||
<div className="flex flex-1">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Generated At</TableHead>
|
||||
<TableHead>Start Range</TableHead>
|
||||
<TableHead>End Range</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Download</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{paginatedTasks.map((task, index) => (
|
||||
<TableRow key={index}>
|
||||
<TableCell>
|
||||
{humanReadableFormatWithTime(task.startTime)}
|
||||
</TableCell>
|
||||
<TableCell>{task.start.toDateString()}</TableCell>
|
||||
<TableCell>{task.end.toDateString()}</TableCell>
|
||||
<TableCell>
|
||||
<ExportBadge status={task.status} />
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{task.status === "SUCCESS" ? (
|
||||
<a
|
||||
className="flex justify-center"
|
||||
href={withRequestId(
|
||||
DOWNLOAD_QUERY_HISTORY_URL,
|
||||
task.taskId
|
||||
)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
)}
|
||||
>
|
||||
<SvgDownloadCloud className="h-4 w-4 text-action-link-05" />
|
||||
</a>
|
||||
) : (
|
||||
<SvgDownloadCloud className="h-4 w-4 text-action-link-05 opacity-20" />
|
||||
)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
<Section>
|
||||
<PageSelector
|
||||
currentPage={taskPage}
|
||||
totalPages={totalTaskPages}
|
||||
onPageChange={setTaskPage}
|
||||
/>
|
||||
</Section>
|
||||
<div className="flex mt-3">
|
||||
<div className="mx-auto">
|
||||
<PageSelector
|
||||
currentPage={taskPage}
|
||||
totalPages={totalTaskPages}
|
||||
onPageChange={setTaskPage}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
@@ -330,48 +330,48 @@ export function QueryHistoryTable() {
|
||||
</div>
|
||||
</div>
|
||||
<Separator />
|
||||
<Section>
|
||||
<Table className="mt-5">
|
||||
<TableHeader>
|
||||
<Table className="mt-5">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>First User Message</TableHead>
|
||||
<TableHead>First AI Response</TableHead>
|
||||
<TableHead>Feedback</TableHead>
|
||||
<TableHead>User</TableHead>
|
||||
<TableHead>Persona</TableHead>
|
||||
<TableHead>Date</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
{isLoading ? (
|
||||
<TableBody>
|
||||
<TableRow>
|
||||
<TableHead>First User Message</TableHead>
|
||||
<TableHead>First AI Response</TableHead>
|
||||
<TableHead>Feedback</TableHead>
|
||||
<TableHead>User</TableHead>
|
||||
<TableHead>Persona</TableHead>
|
||||
<TableHead>Date</TableHead>
|
||||
<TableCell colSpan={6} className="text-center">
|
||||
<ThreeDotsLoader />
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
{isLoading ? (
|
||||
<TableBody>
|
||||
<TableRow>
|
||||
<TableCell colSpan={6} className="text-center">
|
||||
<ThreeDotsLoader />
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
) : (
|
||||
<TableBody>
|
||||
{chatSessionData?.map((chatSessionMinimal) => (
|
||||
<QueryHistoryTableRow
|
||||
key={chatSessionMinimal.id}
|
||||
chatSessionMinimal={chatSessionMinimal}
|
||||
/>
|
||||
))}
|
||||
</TableBody>
|
||||
)}
|
||||
</Table>
|
||||
</TableBody>
|
||||
) : (
|
||||
<TableBody>
|
||||
{chatSessionData?.map((chatSessionMinimal) => (
|
||||
<QueryHistoryTableRow
|
||||
key={chatSessionMinimal.id}
|
||||
chatSessionMinimal={chatSessionMinimal}
|
||||
/>
|
||||
))}
|
||||
</TableBody>
|
||||
)}
|
||||
</Table>
|
||||
|
||||
{chatSessionData && (
|
||||
<Section>
|
||||
{chatSessionData && (
|
||||
<div className="mt-3 flex">
|
||||
<div className="mx-auto">
|
||||
<PageSelector
|
||||
totalPages={totalPages}
|
||||
currentPage={currentPage}
|
||||
onPageChange={goToPage}
|
||||
/>
|
||||
</Section>
|
||||
)}
|
||||
</Section>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</CardSection>
|
||||
|
||||
{showModal && (
|
||||
|
||||
@@ -330,14 +330,6 @@
|
||||
letter-spacing: 0px;
|
||||
}
|
||||
|
||||
.font-secondary-mono-label {
|
||||
font-family: var(--font-dm-mono);
|
||||
font-size: 12px;
|
||||
font-weight: 500;
|
||||
line-height: 16px;
|
||||
letter-spacing: 0px;
|
||||
}
|
||||
|
||||
/* FIGURE STYLES */
|
||||
|
||||
.font-figure-small-label {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import useSWR from "swr";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import {
|
||||
UserSpecificAgentPreference,
|
||||
UserSpecificAgentPreferences,
|
||||
@@ -10,6 +9,7 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { useCallback } from "react";
|
||||
|
||||
// TODO: rename to agent — https://linear.app/onyx-app/issue/ENG-3766
|
||||
const AGENT_PREFERENCES_URL = "/api/user/assistant/preferences";
|
||||
|
||||
// TODO: rename to agent — https://linear.app/onyx-app/issue/ENG-3766
|
||||
const buildUpdateAgentPreferenceUrl = (agentId: number) =>
|
||||
@@ -21,11 +21,10 @@ const buildUpdateAgentPreferenceUrl = (agentId: number) =>
|
||||
*/
|
||||
export default function useAgentPreferences() {
|
||||
const { data, mutate } = useSWR<UserSpecificAgentPreferences>(
|
||||
SWR_KEYS.agentPreferences,
|
||||
AGENT_PREFERENCES_URL,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import useSWR from "swr";
|
||||
import { useState, useEffect, useMemo, useCallback } from "react";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import {
|
||||
MinimalPersonaSnapshot,
|
||||
FullPersona,
|
||||
@@ -37,11 +36,10 @@ import useChatSessions from "./useChatSessions";
|
||||
*/
|
||||
export function useAgents() {
|
||||
const { data, error, mutate } = useSWR<MinimalPersonaSnapshot[]>(
|
||||
SWR_KEYS.personas,
|
||||
"/api/persona",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
@@ -78,11 +76,10 @@ export function useAgents() {
|
||||
*/
|
||||
export function useAgent(agentId: number | null) {
|
||||
const { data, error, isLoading, mutate } = useSWR<FullPersona>(
|
||||
agentId ? SWR_KEYS.persona(agentId) : null,
|
||||
agentId ? `/api/persona/${agentId}` : null,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import useSWR from "swr";
|
||||
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
interface AuthTypeAPIResponse {
|
||||
auth_type: string;
|
||||
@@ -55,12 +54,11 @@ export function useAuthTypeMetadata(): {
|
||||
error: Error | undefined;
|
||||
} {
|
||||
const { data, error, isLoading } = useSWR<AuthTypeMetadata>(
|
||||
SWR_KEYS.authType,
|
||||
"/api/auth/type",
|
||||
fetchAuthTypeMetadata,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateOnReconnect: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 30_000,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import useSWR from "swr";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
/**
|
||||
* Hook to fetch all available tools from the backend.
|
||||
@@ -25,12 +24,10 @@ import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
*/
|
||||
export function useAvailableTools() {
|
||||
const { data, error, mutate } = useSWR<ToolSnapshot[]>(
|
||||
SWR_KEYS.tools,
|
||||
"/api/tool",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
revalidateOnFocus: true,
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import useSWR from "swr";
|
||||
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import {
|
||||
BillingInformation,
|
||||
SubscriptionStatus,
|
||||
@@ -17,15 +16,14 @@ import {
|
||||
*/
|
||||
export function useBillingInformation() {
|
||||
const url = NEXT_PUBLIC_CLOUD_ENABLED
|
||||
? SWR_KEYS.billingInformationCloud
|
||||
: SWR_KEYS.billingInformationSelfHosted;
|
||||
? "/api/tenants/billing-information"
|
||||
: "/api/admin/billing/billing-information";
|
||||
|
||||
const { data, error, mutate, isLoading } = useSWR<
|
||||
BillingInformation | SubscriptionStatus
|
||||
>(url, errorHandlingFetcher, {
|
||||
revalidateOnFocus: false,
|
||||
revalidateOnReconnect: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 30000,
|
||||
shouldRetryOnError: false,
|
||||
keepPreviousData: true,
|
||||
|
||||
@@ -901,6 +901,11 @@ export default function useChatController({
|
||||
});
|
||||
}
|
||||
}
|
||||
// Surface FIFO errors (e.g. 429 before any packets arrive) so the
|
||||
// catch block replaces the thinking placeholder with an error message.
|
||||
if (stack.error) {
|
||||
throw new Error(stack.error);
|
||||
}
|
||||
} catch (e: any) {
|
||||
console.log("Error:", e);
|
||||
const errorMsg = e.message;
|
||||
|
||||
@@ -10,7 +10,6 @@ import {
|
||||
import useSWRInfinite from "swr/infinite";
|
||||
import { ChatSession, ChatSessionSharedStatus } from "@/app/app/interfaces";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
|
||||
import useAppFocus from "./useAppFocus";
|
||||
import { useAgents } from "./useAgents";
|
||||
@@ -147,7 +146,7 @@ export default function useChatSessions(): UseChatSessionsOutput {
|
||||
|
||||
// First page — no cursor
|
||||
if (pageIndex === 0) {
|
||||
return `${SWR_KEYS.chatSessions}?page_size=${PAGE_SIZE}`;
|
||||
return `/api/chat/get-user-chat-sessions?page_size=${PAGE_SIZE}`;
|
||||
}
|
||||
|
||||
// Subsequent pages — cursor from the last session of the previous page
|
||||
@@ -159,7 +158,7 @@ export default function useChatSessions(): UseChatSessionsOutput {
|
||||
page_size: PAGE_SIZE.toString(),
|
||||
before: lastSession.time_updated,
|
||||
});
|
||||
return `${SWR_KEYS.chatSessions}?${params.toString()}`;
|
||||
return `/api/chat/get-user-chat-sessions?${params.toString()}`;
|
||||
};
|
||||
|
||||
const { data, error, setSize, mutate } = useSWRInfinite<ChatSessionsResponse>(
|
||||
@@ -167,7 +166,6 @@ export default function useChatSessions(): UseChatSessionsOutput {
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
revalidateFirstPage: true,
|
||||
revalidateAll: false,
|
||||
dedupingInterval: 30000,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import useSWR, { type KeyedMutator } from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { User } from "@/lib/types";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
/**
|
||||
* Fetches the current authenticated user via SWR (`/api/me`).
|
||||
@@ -30,12 +29,11 @@ export function useCurrentUser(): {
|
||||
userError: (Error & { status?: number }) | undefined;
|
||||
} {
|
||||
const { data, mutate, error } = useSWR<User>(
|
||||
SWR_KEYS.me,
|
||||
"/api/me",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateOnReconnect: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 30_000,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
import useSWR from "swr";
|
||||
import { fetchExecutionLogs } from "@/refresh-pages/admin/HooksPage/svc";
|
||||
import type { HookExecutionRecord } from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
const ONE_HOUR_MS = 60 * 60 * 1000;
|
||||
const THIRTY_DAYS_MS = 30 * 24 * 60 * 60 * 1000;
|
||||
|
||||
interface UseHookExecutionLogsResult {
|
||||
isLoading: boolean;
|
||||
error: Error | undefined;
|
||||
hasRecentErrors: boolean;
|
||||
recentErrors: HookExecutionRecord[];
|
||||
olderErrors: HookExecutionRecord[];
|
||||
}
|
||||
|
||||
export function useHookExecutionLogs(
|
||||
hookId: number,
|
||||
limit = 10
|
||||
): UseHookExecutionLogsResult {
|
||||
const { data, isLoading, error } = useSWR(
|
||||
["hook-execution-logs", hookId, limit],
|
||||
() => fetchExecutionLogs(hookId, limit),
|
||||
{ refreshInterval: 60_000 }
|
||||
);
|
||||
|
||||
const now = Date.now();
|
||||
|
||||
const recentErrors =
|
||||
data?.filter(
|
||||
(log) => now - new Date(log.created_at).getTime() < ONE_HOUR_MS
|
||||
) ?? [];
|
||||
|
||||
const olderErrors =
|
||||
data?.filter((log) => {
|
||||
const age = now - new Date(log.created_at).getTime();
|
||||
return age >= ONE_HOUR_MS && age < THIRTY_DAYS_MS;
|
||||
}) ?? [];
|
||||
|
||||
const hasRecentErrors = recentErrors.length > 0;
|
||||
|
||||
return { isLoading, error, hasRecentErrors, recentErrors, olderErrors };
|
||||
}
|
||||
@@ -2,13 +2,13 @@
|
||||
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import {
|
||||
LLMProviderDescriptor,
|
||||
LLMProviderResponse,
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
|
||||
/**
|
||||
* Fetches configured LLM providers accessible to the current user.
|
||||
@@ -45,14 +45,13 @@ import {
|
||||
export function useLLMProviders(personaId?: number) {
|
||||
const url =
|
||||
personaId !== undefined
|
||||
? SWR_KEYS.llmProvidersForPersona(personaId)
|
||||
: SWR_KEYS.llmProviders;
|
||||
? `/api/llm/persona/${personaId}/providers`
|
||||
: "/api/llm/provider";
|
||||
|
||||
const { data, error, mutate } = useSWR<
|
||||
LLMProviderResponse<LLMProviderDescriptor>
|
||||
>(url, errorHandlingFetcher, {
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
});
|
||||
|
||||
@@ -89,11 +88,10 @@ export function useLLMProviders(personaId?: number) {
|
||||
*/
|
||||
export function useAdminLLMProviders() {
|
||||
const { data, error, mutate } = useSWR<LLMProviderResponse<LLMProviderView>>(
|
||||
SWR_KEYS.adminLlmProviders,
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
@@ -143,11 +141,12 @@ export function useAdminLLMProviders() {
|
||||
*/
|
||||
export function useWellKnownLLMProvider(providerEndpoint: string | null) {
|
||||
const { data, error, isLoading } = useSWR<WellKnownLLMProviderDescriptor>(
|
||||
providerEndpoint ? SWR_KEYS.wellKnownLlmProvider(providerEndpoint) : null,
|
||||
providerEndpoint
|
||||
? `/api/admin/llm/built-in/options/${providerEndpoint}`
|
||||
: null,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
@@ -166,11 +165,10 @@ export function useWellKnownLLMProviders() {
|
||||
isLoading,
|
||||
mutate,
|
||||
} = useSWR<WellKnownLLMProviderDescriptor[]>(
|
||||
SWR_KEYS.wellKnownLlmProviders,
|
||||
"/api/admin/llm/built-in/options",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -3,7 +3,6 @@ import useSWR from "swr";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { LicenseStatus } from "@/lib/billing/interfaces";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
/**
|
||||
* Hook to fetch license status for self-hosted deployments.
|
||||
@@ -11,7 +10,7 @@ import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
* Skips the fetch on cloud deployments (uses tenant auth instead).
|
||||
*/
|
||||
export function useLicense() {
|
||||
const url = NEXT_PUBLIC_CLOUD_ENABLED ? null : SWR_KEYS.license;
|
||||
const url = NEXT_PUBLIC_CLOUD_ENABLED ? null : "/api/license";
|
||||
|
||||
const { data, error, mutate, isLoading } = useSWR<LicenseStatus>(
|
||||
url,
|
||||
@@ -19,7 +18,6 @@ export function useLicense() {
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateOnReconnect: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 30000,
|
||||
shouldRetryOnError: false,
|
||||
keepPreviousData: true,
|
||||
|
||||
@@ -3,17 +3,11 @@
|
||||
import useSWR from "swr";
|
||||
import { InputPrompt } from "@/app/app/interfaces";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
export default function usePromptShortcuts() {
|
||||
const { data, error, isLoading, mutate } = useSWR<InputPrompt[]>(
|
||||
SWR_KEYS.promptShortcuts,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
"/api/input_prompt",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const promptShortcuts = data ?? [];
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import {
|
||||
Settings,
|
||||
EnterpriseSettings,
|
||||
@@ -33,12 +32,11 @@ export function useSettings(): {
|
||||
error: Error | undefined;
|
||||
} {
|
||||
const { data, error, isLoading } = useSWR<Settings>(
|
||||
SWR_KEYS.settings,
|
||||
"/api/settings",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateOnReconnect: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 30_000,
|
||||
errorRetryInterval: SETTINGS_ERROR_RETRY_INTERVAL,
|
||||
}
|
||||
@@ -63,12 +61,11 @@ export function useEnterpriseSettings(eeEnabledRuntime: boolean): {
|
||||
const shouldFetch = EE_ENABLED || eeEnabledRuntime;
|
||||
|
||||
const { data, error, isLoading } = useSWR<EnterpriseSettings>(
|
||||
shouldFetch ? SWR_KEYS.enterpriseSettings : null,
|
||||
shouldFetch ? "/api/enterprise-settings" : null,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateOnReconnect: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 30_000,
|
||||
errorRetryInterval: SETTINGS_ERROR_RETRY_INTERVAL,
|
||||
// Referential equality instead of SWR's default deep comparison.
|
||||
@@ -92,12 +89,11 @@ export function useCustomAnalyticsScript(
|
||||
const shouldFetch = EE_ENABLED || eeEnabledRuntime;
|
||||
|
||||
const { data } = useSWR<string>(
|
||||
shouldFetch ? SWR_KEYS.customAnalyticsScript : null,
|
||||
shouldFetch ? "/api/enterprise-settings/custom-analytics-script" : null,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateOnReconnect: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60_000,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import useSWR from "swr";
|
||||
import { Tag } from "@/lib/types";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
interface TagsResponse {
|
||||
tags: Tag[];
|
||||
@@ -19,11 +18,10 @@ interface TagsResponse {
|
||||
*/
|
||||
export default function useTags() {
|
||||
const { data, error, mutate } = useSWR<TagsResponse>(
|
||||
SWR_KEYS.tags,
|
||||
"/api/query/valid-tags",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
export interface VoiceProviderView {
|
||||
id: number;
|
||||
@@ -15,13 +14,14 @@ export interface VoiceProviderView {
|
||||
target_uri: string | null;
|
||||
}
|
||||
|
||||
const VOICE_PROVIDERS_URL = "/api/admin/voice/providers";
|
||||
|
||||
export function useVoiceProviders() {
|
||||
const { data, error, isLoading, mutate } = useSWR<VoiceProviderView[]>(
|
||||
SWR_KEYS.voiceProviders,
|
||||
VOICE_PROVIDERS_URL,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -6,8 +6,6 @@ import { INTERNAL_URL, IS_DEV } from "@/lib/constants";
|
||||
const TARGET_SAMPLE_RATE = 24000;
|
||||
const CHUNK_INTERVAL_MS = 250;
|
||||
const DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS = 1500;
|
||||
// When VAD-based auto-stop is disabled, force-stop after this much silence as a fallback
|
||||
const SILENCE_FALLBACK_TIMEOUT_MS = 10000;
|
||||
|
||||
interface TranscriptMessage {
|
||||
type: "transcript" | "error";
|
||||
@@ -60,8 +58,6 @@ class VoiceRecorderSession {
|
||||
private finalTranscriptDelivered = false;
|
||||
private lastDeliveredFinalText: string | null = null;
|
||||
private lastDeliveredFinalAtMs = 0;
|
||||
// Fallback timer: force-stop after extended silence when VAD auto-stop is disabled
|
||||
private silenceFallbackTimer: NodeJS.Timeout | null = null;
|
||||
|
||||
// Callbacks to update React state
|
||||
private onTranscriptChange: (text: string) => void;
|
||||
@@ -178,8 +174,6 @@ class VoiceRecorderSession {
|
||||
async stop(): Promise<string | null> {
|
||||
if (!this.isActive) return this.transcript || null;
|
||||
|
||||
this.resetSilenceFallbackTimer();
|
||||
|
||||
// Stop audio capture
|
||||
if (this.sendInterval) {
|
||||
clearInterval(this.sendInterval);
|
||||
@@ -225,7 +219,6 @@ class VoiceRecorderSession {
|
||||
}
|
||||
|
||||
cleanup(): void {
|
||||
this.resetSilenceFallbackTimer();
|
||||
if (this.sendInterval) clearInterval(this.sendInterval);
|
||||
if (this.scriptNode) this.scriptNode.disconnect();
|
||||
if (this.sourceNode) this.sourceNode.disconnect();
|
||||
@@ -281,23 +274,6 @@ class VoiceRecorderSession {
|
||||
});
|
||||
}
|
||||
|
||||
private resetSilenceFallbackTimer(): void {
|
||||
if (this.silenceFallbackTimer) {
|
||||
clearTimeout(this.silenceFallbackTimer);
|
||||
this.silenceFallbackTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
private startSilenceFallbackTimer(): void {
|
||||
this.resetSilenceFallbackTimer();
|
||||
this.silenceFallbackTimer = setTimeout(() => {
|
||||
// 10s of silence with no new speech — force-stop as a safety fallback
|
||||
if (this.isActive && this.onVADStop) {
|
||||
this.onVADStop();
|
||||
}
|
||||
}, SILENCE_FALLBACK_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
private handleMessage = (event: MessageEvent): void => {
|
||||
try {
|
||||
const data: TranscriptMessage = JSON.parse(event.data);
|
||||
@@ -305,53 +281,47 @@ class VoiceRecorderSession {
|
||||
if (data.type === "transcript") {
|
||||
if (data.text) {
|
||||
this.transcript = data.text;
|
||||
// Only push live updates to React while actively recording.
|
||||
// After stop(), the final transcript is returned via stopResolver
|
||||
// instead — this prevents stale text from reappearing in the
|
||||
// input box when the user clears it and starts a new recording.
|
||||
if (this.isActive) {
|
||||
this.onTranscriptChange(data.text);
|
||||
}
|
||||
this.onTranscriptChange(data.text);
|
||||
}
|
||||
|
||||
if (data.is_final && data.text) {
|
||||
// Resolve stop promise if waiting — must run even after stop()
|
||||
// so the caller receives the final transcript.
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(data.text);
|
||||
this.stopResolver = null;
|
||||
// VAD detected silence - trigger callback (only once per utterance)
|
||||
const now = Date.now();
|
||||
const isLikelyDuplicateFinal =
|
||||
this.autoStopOnSilence &&
|
||||
this.lastDeliveredFinalText === data.text &&
|
||||
now - this.lastDeliveredFinalAtMs <
|
||||
DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS;
|
||||
|
||||
if (
|
||||
this.onFinalTranscript &&
|
||||
!this.finalTranscriptDelivered &&
|
||||
!isLikelyDuplicateFinal
|
||||
) {
|
||||
this.finalTranscriptDelivered = true;
|
||||
this.lastDeliveredFinalText = data.text;
|
||||
this.lastDeliveredFinalAtMs = now;
|
||||
this.onFinalTranscript(data.text);
|
||||
}
|
||||
|
||||
// Skip VAD logic if session is no longer active
|
||||
if (!this.isActive) return;
|
||||
|
||||
// Auto-stop recording if enabled
|
||||
if (this.autoStopOnSilence) {
|
||||
// VAD detected silence — auto-stop and trigger callback
|
||||
const now = Date.now();
|
||||
const isLikelyDuplicateFinal =
|
||||
this.lastDeliveredFinalText === data.text &&
|
||||
now - this.lastDeliveredFinalAtMs <
|
||||
DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS;
|
||||
|
||||
if (
|
||||
this.onFinalTranscript &&
|
||||
!this.finalTranscriptDelivered &&
|
||||
!isLikelyDuplicateFinal
|
||||
) {
|
||||
this.finalTranscriptDelivered = true;
|
||||
this.lastDeliveredFinalText = data.text;
|
||||
this.lastDeliveredFinalAtMs = now;
|
||||
this.onFinalTranscript(data.text);
|
||||
}
|
||||
|
||||
// Trigger stop callback to update React state
|
||||
if (this.onVADStop) {
|
||||
this.onVADStop();
|
||||
}
|
||||
} else {
|
||||
// Auto-stop disabled (push-to-talk): ignore VAD, keep recording.
|
||||
// Start/reset a 10s fallback timer — if no new speech arrives,
|
||||
// force-stop to avoid recording silence indefinitely.
|
||||
this.startSilenceFallbackTimer();
|
||||
// If not auto-stopping, reset for next utterance
|
||||
this.transcript = "";
|
||||
this.finalTranscriptDelivered = false;
|
||||
this.onTranscriptChange("");
|
||||
this.resetBackendTranscript();
|
||||
}
|
||||
|
||||
// Resolve stop promise if waiting
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(data.text);
|
||||
this.stopResolver = null;
|
||||
}
|
||||
}
|
||||
} else if (data.type === "error") {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
interface VoiceStatus {
|
||||
stt_enabled: boolean;
|
||||
@@ -9,11 +8,10 @@ interface VoiceStatus {
|
||||
|
||||
export function useVoiceStatus() {
|
||||
const { data, error, isLoading } = useSWR<VoiceStatus>(
|
||||
SWR_KEYS.voiceStatus,
|
||||
"/api/voice/status",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -7,7 +7,6 @@ import * as Sentry from "@sentry/nextjs";
|
||||
if (process.env.NEXT_PUBLIC_SENTRY_DSN) {
|
||||
Sentry.init({
|
||||
dsn: process.env.NEXT_PUBLIC_SENTRY_DSN,
|
||||
release: process.env.SENTRY_RELEASE,
|
||||
|
||||
// Setting this option to true will print useful information to the console while you're setting up Sentry.
|
||||
debug: false,
|
||||
|
||||
@@ -160,31 +160,6 @@ export const formatDateShort = (dateStr: string | null | undefined): string => {
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Format an ISO timestamp as "YYYY/MM/DD HH:MM:SS" (24-hour, local time).
|
||||
* Intended for log displays where full precision is needed.
|
||||
*/
|
||||
export function formatDateTimeLog(iso: string): string {
|
||||
const d = new Date(iso);
|
||||
const pad = (n: number) => String(n).padStart(2, "0");
|
||||
return `${d.getFullYear()}/${pad(d.getMonth() + 1)}/${pad(d.getDate())} ${pad(
|
||||
d.getHours()
|
||||
)}:${pad(d.getMinutes())}:${pad(d.getSeconds())}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format an ISO timestamp as "HH:MM:SS" (24-hour, local time).
|
||||
* Intended for compact time-only displays.
|
||||
*/
|
||||
export function formatTimeOnly(iso: string): string {
|
||||
return new Date(iso).toLocaleTimeString(undefined, {
|
||||
hour: "2-digit",
|
||||
minute: "2-digit",
|
||||
second: "2-digit",
|
||||
hour12: false,
|
||||
});
|
||||
}
|
||||
|
||||
export function formatMmDdYyyy(d: string): string {
|
||||
const date = new Date(d);
|
||||
return `${date.getMonth() + 1}/${date.getDate()}/${date.getFullYear()}`;
|
||||
|
||||
@@ -96,54 +96,6 @@ describe("LLM resolver helpers", () => {
|
||||
});
|
||||
});
|
||||
|
||||
test("prefers provider by name when multiple share the same type", () => {
|
||||
const providers: LLMProviderDescriptor[] = [
|
||||
makeProvider({
|
||||
id: 1,
|
||||
name: "Anthropic",
|
||||
provider: "anthropic",
|
||||
model_configurations: [
|
||||
{
|
||||
name: "claude-sonnet-4-5",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
makeProvider({
|
||||
id: 2,
|
||||
name: "PersonalAnthropicToken",
|
||||
provider: "anthropic",
|
||||
model_configurations: [
|
||||
{
|
||||
name: "claude-sonnet-4-5",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const descriptor = getValidLlmDescriptorForProviders(
|
||||
structureValue(
|
||||
"PersonalAnthropicToken",
|
||||
"anthropic",
|
||||
"claude-sonnet-4-5"
|
||||
),
|
||||
providers
|
||||
);
|
||||
|
||||
expect(descriptor).toEqual({
|
||||
name: "PersonalAnthropicToken",
|
||||
provider: "anthropic",
|
||||
modelName: "claude-sonnet-4-5",
|
||||
});
|
||||
});
|
||||
|
||||
test("uses first provider with models when no explicit default exists", () => {
|
||||
const providers: LLMProviderDescriptor[] = [
|
||||
makeProvider({
|
||||
|
||||
@@ -603,16 +603,13 @@ export function getValidLlmDescriptorForProviders(
|
||||
// This ensures we don't incorrectly match a model to the wrong provider
|
||||
// when the same model name exists across multiple providers (e.g., gpt-5 in Azure and OpenAI)
|
||||
if (model.provider && model.provider.length > 0) {
|
||||
const hasModel = (p: LLMProviderDescriptor) =>
|
||||
p.model_configurations.some((mc) => mc.name === model.modelName);
|
||||
const typeMatches = llmProviders.filter(
|
||||
(p) => p.provider === model.provider && hasModel(p)
|
||||
const matchingProvider = llmProviders.find(
|
||||
(p) =>
|
||||
p.provider === model.provider &&
|
||||
p.model_configurations
|
||||
.map((modelConfiguration) => modelConfiguration.name)
|
||||
.includes(model.modelName)
|
||||
);
|
||||
// When multiple providers share the same type (e.g., two "anthropic"
|
||||
// providers with different API keys), prefer the one whose name matches
|
||||
// the user's explicit selection to avoid silently switching providers.
|
||||
const matchingProvider =
|
||||
typeMatches.find((p) => p.name === model.name) ?? typeMatches[0];
|
||||
if (matchingProvider) {
|
||||
return {
|
||||
...model,
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import useSWR from "swr";
|
||||
import { DocumentSetSummary } from "@/lib/types";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
export function useDocumentSets() {
|
||||
const { data, error, mutate } = useSWR<DocumentSetSummary[]>(
|
||||
SWR_KEYS.documentSets,
|
||||
"/api/manage/document-set",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import useSWR from "swr";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
export function useLLMProviderOptions() {
|
||||
const { data, error, mutate } = useSWR<
|
||||
WellKnownLLMProviderDescriptor[] | undefined
|
||||
>(SWR_KEYS.wellKnownLlmProviders, errorHandlingFetcher, {
|
||||
>("/api/admin/llm/built-in/options", errorHandlingFetcher, {
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 60000,
|
||||
dedupingInterval: 60000, // Dedupe requests within 1 minute
|
||||
});
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import useSWR from "swr";
|
||||
import { Project } from "@/app/app/projects/projectsService";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
export function useProjects() {
|
||||
const { data, error, mutate } = useSWR<Project[]>(
|
||||
SWR_KEYS.userProjects,
|
||||
"/api/user/projects",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateIfStale: false,
|
||||
dedupingInterval: 30000,
|
||||
}
|
||||
);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user