mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-01 13:02:42 +00:00
Compare commits
32 Commits
multi-mode
...
bo/hook_ee
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39a3ee1a0a | ||
|
|
a1c3a68ba4 | ||
|
|
4fb175ae65 | ||
|
|
800ad326df | ||
|
|
6b920e8a3e | ||
|
|
ef3760796d | ||
|
|
fa5b90df92 | ||
|
|
53953ac4fa | ||
|
|
26bb5c990c | ||
|
|
27b4ed301f | ||
|
|
93ec270ccc | ||
|
|
9e2d6c8a1d | ||
|
|
fc934214d0 | ||
|
|
48fc45a0cd | ||
|
|
009266e53e | ||
|
|
ffb9df7308 | ||
|
|
b0f5e0b8d9 | ||
|
|
43aea5d614 | ||
|
|
593d82f431 | ||
|
|
adf5691b5f | ||
|
|
c1a8a5bd83 | ||
|
|
8fd486da99 | ||
|
|
4bda4d3637 | ||
|
|
13c25eadad | ||
|
|
1f244e6388 | ||
|
|
18b0416d30 | ||
|
|
4bc0bc1efb | ||
|
|
1555217061 | ||
|
|
d177a833f0 | ||
|
|
086997d3c5 | ||
|
|
dccec78397 | ||
|
|
0123133621 |
6
.github/workflows/deployment.yml
vendored
6
.github/workflows/deployment.yml
vendored
@@ -704,6 +704,9 @@ jobs:
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
SENTRY_RELEASE=${{ github.sha }}
|
||||
secrets: |
|
||||
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
@@ -786,6 +789,9 @@ jobs:
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
SENTRY_RELEASE=${{ github.sha }}
|
||||
secrets: |
|
||||
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
|
||||
2
.github/workflows/pr-helm-chart-testing.yml
vendored
2
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
|
||||
uses: helm/chart-testing-action@2e2940618cb426dce2999631d543b53cdcfc8527
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from onyx.background.celery.apps.primary import celery_app
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.hooks",
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
|
||||
@@ -55,6 +55,15 @@ ee_tasks_to_schedule: list[dict] = []
|
||||
|
||||
if not MULTI_TENANT:
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
"schedule": timedelta(days=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
|
||||
@@ -13,6 +13,7 @@ from redis.lock import Lock as RedisLock
|
||||
from ee.onyx.server.tenants.provisioning import setup_tenant
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
|
||||
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
@@ -29,9 +30,10 @@ from shared_configs.configs import TENANT_ID_PREFIX
|
||||
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
|
||||
_MAX_TENANTS_PER_RUN = 5
|
||||
|
||||
# Time limits sized for worst-case batch: _MAX_TENANTS_PER_RUN × ~90s + buffer.
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 15 # 15 minutes
|
||||
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
|
||||
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -91,8 +93,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
|
||||
if batch_size < tenants_to_provision:
|
||||
task_logger.info(
|
||||
f"Capping batch to {batch_size} "
|
||||
f"(need {tenants_to_provision}, will catch up next cycle)"
|
||||
f"Capping batch to {batch_size} (need {tenants_to_provision}, will catch up next cycle)"
|
||||
)
|
||||
|
||||
provisioned = 0
|
||||
@@ -103,12 +104,14 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
provisioned += 1
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Failed to provision tenant {i + 1}/{batch_size}, "
|
||||
"continuing with remaining tenants"
|
||||
f"Failed to provision tenant {i + 1}/{batch_size}, continuing with remaining tenants"
|
||||
)
|
||||
|
||||
task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")
|
||||
|
||||
# Migrate any pool tenants that were provisioned before a new migration was deployed
|
||||
_migrate_stale_pool_tenants()
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
|
||||
@@ -121,6 +124,46 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
)
|
||||
|
||||
|
||||
def _migrate_stale_pool_tenants() -> None:
|
||||
"""
|
||||
Run alembic upgrade head on all pool tenants. Since alembic upgrade head is
|
||||
idempotent, tenants already at head are a fast no-op. This ensures pool
|
||||
tenants are always current so that signup doesn't hit schema mismatches
|
||||
(e.g. missing columns added after the tenant was pre-provisioned).
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
pool_tenants = db_session.query(AvailableTenant).all()
|
||||
tenant_ids = [t.tenant_id for t in pool_tenants]
|
||||
|
||||
if not tenant_ids:
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Checking {len(tenant_ids)} pool tenant(s) for pending migrations"
|
||||
)
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
try:
|
||||
run_alembic_migrations(tenant_id)
|
||||
new_version = get_current_alembic_version(tenant_id)
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.filter_by(tenant_id=tenant_id)
|
||||
.first()
|
||||
)
|
||||
if tenant and tenant.alembic_version != new_version:
|
||||
task_logger.info(
|
||||
f"Migrated pool tenant {tenant_id}: {tenant.alembic_version} -> {new_version}"
|
||||
)
|
||||
tenant.alembic_version = new_version
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Failed to migrate pool tenant {tenant_id}, skipping"
|
||||
)
|
||||
|
||||
|
||||
def pre_provision_tenant() -> bool:
|
||||
"""
|
||||
Pre-provision a new tenant and store it in the NewAvailableTenant table.
|
||||
|
||||
@@ -69,5 +69,7 @@ EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
|
||||
"/admin/token-rate-limits",
|
||||
# Evals
|
||||
"/evals",
|
||||
# Hook extensions
|
||||
"/admin/hooks",
|
||||
}
|
||||
)
|
||||
|
||||
0
backend/ee/onyx/hooks/__init__.py
Normal file
0
backend/ee/onyx/hooks/__init__.py
Normal file
385
backend/ee/onyx/hooks/executor.py
Normal file
385
backend/ee/onyx/hooks/executor.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is a validated Pydantic model instance (response_type)
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _execute_hook_inner(
|
||||
hook: Hook,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSoftFailed:
|
||||
"""Make the HTTP call, validate the response, and return a typed model.
|
||||
|
||||
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
|
||||
"""
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(
|
||||
timeout=timeout, follow_redirects=False
|
||||
) as client: # SSRF guard: never follow redirects
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
|
||||
# Validate the response payload against response_type.
|
||||
# A validation failure downgrades the outcome to a failure so it is logged,
|
||||
# is_reachable is left unchanged (server responded — just a bad payload),
|
||||
# and fail_strategy is respected below.
|
||||
validated_model: T | None = None
|
||||
if outcome.is_success and outcome.response_payload is not None:
|
||||
try:
|
||||
validated_model = response_type.model_validate(outcome.response_payload)
|
||||
except ValidationError as e:
|
||||
msg = (
|
||||
f"Hook response failed validation against {response_type.__name__}: {e}"
|
||||
)
|
||||
outcome = _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=outcome.status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
|
||||
if validated_model is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"validated_model is None for successful hook call (hook_id={hook_id})",
|
||||
)
|
||||
return validated_model
|
||||
|
||||
|
||||
def _execute_hook_impl(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""EE implementation — loaded by CE's execute_hook via fetch_versioned_implementation.
|
||||
|
||||
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
|
||||
hook failed with SOFT fail strategy, or a validated response model on success.
|
||||
Raises OnyxError on HARD failure or if the hook is misconfigured.
|
||||
"""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
fail_strategy = hook.fail_strategy
|
||||
hook_id = hook.id
|
||||
|
||||
try:
|
||||
return _execute_hook_inner(hook, payload, response_type)
|
||||
except Exception:
|
||||
if fail_strategy == HookFailStrategy.SOFT:
|
||||
logger.exception(
|
||||
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
raise
|
||||
@@ -15,6 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
basic_router as enterprise_settings_router,
|
||||
)
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.features.hooks.api import router as hook_router
|
||||
from ee.onyx.server.license.api import router as license_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.license_enforcement import (
|
||||
@@ -138,6 +139,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
|
||||
include_router_with_global_prefix_prepended(application, evals_router)
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
|
||||
# Enterprise-only global settings
|
||||
include_router_with_global_prefix_prepended(
|
||||
|
||||
0
backend/ee/onyx/server/features/__init__.py
Normal file
0
backend/ee/onyx/server/features/__init__.py
Normal file
0
backend/ee/onyx/server/features/hooks/__init__.py
Normal file
0
backend/ee/onyx/server/features/hooks/__init__.py
Normal file
@@ -99,6 +99,26 @@ async def get_or_provision_tenant(
|
||||
tenant_id = await get_available_tenant()
|
||||
|
||||
if tenant_id:
|
||||
# Run migrations to ensure the pre-provisioned tenant schema is current.
|
||||
# Pool tenants may have been created before a new migration was deployed.
|
||||
# Capture as a non-optional local so mypy can type the lambda correctly.
|
||||
_tenant_id: str = tenant_id
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None, lambda: run_alembic_migrations(_tenant_id)
|
||||
)
|
||||
except Exception:
|
||||
# The tenant was already dequeued from the pool — roll it back so
|
||||
# it doesn't end up orphaned (schema exists, but not assigned to anyone).
|
||||
logger.exception(
|
||||
f"Migration failed for pre-provisioned tenant {_tenant_id}; rolling back"
|
||||
)
|
||||
try:
|
||||
await rollback_tenant_provisioning(_tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to rollback orphaned tenant {_tenant_id}")
|
||||
raise
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
|
||||
@@ -100,6 +100,7 @@ def get_model_app() -> FastAPI:
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -20,6 +20,7 @@ from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
@@ -65,6 +66,7 @@ if SENTRY_DSN:
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
@@ -515,7 +517,8 @@ def reset_tenant_id(
|
||||
|
||||
|
||||
def wait_for_vespa_or_shutdown(
|
||||
sender: Any, **kwargs: Any # noqa: ARG001
|
||||
sender: Any, # noqa: ARG001
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None: # noqa: ARG001
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
@@ -317,7 +317,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.hooks",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
|
||||
@@ -14,7 +14,6 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
@@ -362,19 +361,6 @@ if not MULTI_TENANT:
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
if HOOKS_AVAILABLE:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
"schedule": timedelta(days=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_cloud_tasks(
|
||||
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
|
||||
|
||||
@@ -9,6 +9,7 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
|
||||
@@ -137,6 +138,7 @@ def _docfetching_task(
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
|
||||
@@ -319,6 +319,11 @@ def monitor_indexing_attempt_progress(
|
||||
)
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
total_batches: int | str = (
|
||||
coordination_status.total_batches
|
||||
if coordination_status.total_batches is not None
|
||||
else "?"
|
||||
)
|
||||
if coordination_status.found:
|
||||
task_logger.info(
|
||||
f"Indexing attempt progress: "
|
||||
@@ -326,7 +331,7 @@ def monitor_indexing_attempt_progress(
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id} "
|
||||
f"completed_batches={coordination_status.completed_batches} "
|
||||
f"total_batches={coordination_status.total_batches or '?'} "
|
||||
f"total_batches={total_batches} "
|
||||
f"total_docs={coordination_status.total_docs} "
|
||||
f"total_failures={coordination_status.total_failures}"
|
||||
f"elapsed={(current_db_time - attempt.time_created).seconds}"
|
||||
@@ -410,7 +415,7 @@ def check_indexing_completion(
|
||||
logger.info(
|
||||
f"Indexing status: "
|
||||
f"indexing_completed={indexing_completed} "
|
||||
f"batches_processed={batches_processed}/{batches_total or '?'} "
|
||||
f"batches_processed={batches_processed}/{batches_total if batches_total is not None else '?'} "
|
||||
f"total_docs={coordination_status.total_docs} "
|
||||
f"total_chunks={coordination_status.total_chunks} "
|
||||
f"total_failures={coordination_status.total_failures}"
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
@@ -148,3 +159,114 @@ class ChatStateContainer:
|
||||
"""Thread-safe getter for emitted citations (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
with event streaming capabilities.
|
||||
|
||||
The wrapped function should accept emitter as first arg and use it to emit
|
||||
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
pass
|
||||
"""
|
||||
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
# Run the function in a background thread
|
||||
thread = run_in_background(run_with_exception_capture)
|
||||
|
||||
pkt: Packet | None = None
|
||||
last_turn_index = 0 # Track the highest turn_index seen for stop packet
|
||||
last_cancel_check = time.monotonic()
|
||||
cancel_check_interval = 0.3 # Check for cancellation every 300ms
|
||||
try:
|
||||
while True:
|
||||
# Poll queue with 300ms timeout for natural stop signal checking
|
||||
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
|
||||
try:
|
||||
pkt = emitter.bus.get(timeout=0.3)
|
||||
except Empty:
|
||||
if not is_connected():
|
||||
# Stop signal detected
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = time.monotonic()
|
||||
continue
|
||||
|
||||
if pkt is not None:
|
||||
# Track the highest turn_index for the stop packet
|
||||
if pkt.placement and pkt.placement.turn_index > last_turn_index:
|
||||
last_turn_index = pkt.placement.turn_index
|
||||
|
||||
if isinstance(pkt.obj, OverallStop):
|
||||
yield pkt
|
||||
break
|
||||
elif isinstance(pkt.obj, PacketException):
|
||||
raise pkt.obj.exception
|
||||
else:
|
||||
yield pkt
|
||||
|
||||
# Check for cancellation periodically even when packets are flowing
|
||||
# This ensures stop signal is checked during active streaming
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_cancel_check >= cancel_check_interval:
|
||||
if not is_connected():
|
||||
# Stop signal detected during streaming
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = current_time
|
||||
finally:
|
||||
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,40 +1,19 @@
|
||||
import threading
|
||||
from queue import Queue
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
class Emitter:
|
||||
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
|
||||
"""Use this inside tools to emit arbitrary UI progress."""
|
||||
|
||||
Tags every packet with ``model_index`` and places it on ``merged_queue``
|
||||
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
|
||||
|
||||
Args:
|
||||
merged_queue: Shared queue owned by ``_run_models``.
|
||||
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
|
||||
drain_done: Optional event set by ``_run_models`` when the drain loop
|
||||
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
|
||||
immediately so worker threads can exit fast.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
merged_queue: Queue[tuple[int, Packet | Exception | object]],
|
||||
model_idx: int = 0,
|
||||
drain_done: threading.Event | None = None,
|
||||
) -> None:
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
self._drain_done = drain_done
|
||||
def __init__(self, bus: Queue):
|
||||
self.bus = bus
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
if self._drain_done is not None and self._drain_done.is_set():
|
||||
return
|
||||
base = packet.placement or Placement(turn_index=0)
|
||||
tagged = Packet(
|
||||
placement=base.model_copy(update={"model_index": self._model_idx}),
|
||||
obj=packet.obj,
|
||||
)
|
||||
self._merged_queue.put((self._model_idx, tagged))
|
||||
self.bus.put(packet) # Thread-safe
|
||||
|
||||
|
||||
def get_default_emitter() -> Emitter:
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
return emitter
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1079,7 +1079,6 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
HOOK_ENABLED = os.environ.get("HOOK_ENABLED", "").lower() == "true"
|
||||
|
||||
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -212,6 +212,7 @@ class DocumentSource(str, Enum):
|
||||
PRODUCTBOARD = "productboard"
|
||||
FILE = "file"
|
||||
CODA = "coda"
|
||||
CANVAS = "canvas"
|
||||
NOTION = "notion"
|
||||
ZULIP = "zulip"
|
||||
LINEAR = "linear"
|
||||
@@ -672,6 +673,7 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
|
||||
DocumentSource.SLAB: "slab data",
|
||||
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
|
||||
DocumentSource.FILE: "files",
|
||||
DocumentSource.CANVAS: "canvas lms - courses, pages, assignments, and announcements",
|
||||
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
|
||||
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
|
||||
project management, and collaboration tools into a single, customizable platform",
|
||||
|
||||
32
backend/onyx/connectors/canvas/access.py
Normal file
32
backend/onyx/connectors/canvas/access.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Permissioning / AccessControl logic for Canvas courses.
|
||||
|
||||
CE stub — returns None (no permissions). The EE implementation is loaded
|
||||
at runtime via ``fetch_versioned_implementation``.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
def get_course_permissions(
|
||||
canvas_client: CanvasApiClient,
|
||||
course_id: int,
|
||||
) -> ExternalAccess | None:
|
||||
if not global_version.is_ee_version():
|
||||
return None
|
||||
|
||||
ee_get_course_permissions = cast(
|
||||
Callable[[CanvasApiClient, int], ExternalAccess | None],
|
||||
fetch_versioned_implementation(
|
||||
"onyx.external_permissions.canvas.access",
|
||||
"get_course_permissions",
|
||||
),
|
||||
)
|
||||
|
||||
return ee_get_course_permissions(canvas_client, course_id)
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -190,3 +191,22 @@ class CanvasApiClient:
|
||||
if clean_endpoint:
|
||||
final_url += "/" + clean_endpoint
|
||||
return final_url
|
||||
|
||||
def paginate(
|
||||
self,
|
||||
endpoint: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Iterator[list[Any]]:
|
||||
"""Yield each page of results, following Link-header pagination.
|
||||
|
||||
Makes the first request with endpoint + params, then follows
|
||||
next_url from Link headers for subsequent pages.
|
||||
"""
|
||||
response, next_url = self.get(endpoint, params=params)
|
||||
while True:
|
||||
if not response:
|
||||
break
|
||||
yield response
|
||||
if not next_url:
|
||||
break
|
||||
response, next_url = self.get(full_url=next_url)
|
||||
|
||||
@@ -1,17 +1,82 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
from typing import NoReturn
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.canvas.access import get_course_permissions
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
|
||||
"""Map Canvas API errors to connector framework exceptions."""
|
||||
if e.status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Canvas API token is invalid or expired (HTTP 401)."
|
||||
)
|
||||
elif e.status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Canvas API token does not have sufficient permissions (HTTP 403)."
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
|
||||
)
|
||||
elif e.status_code >= 500:
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
|
||||
)
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"Canvas API error (status={e.status_code}): {e}"
|
||||
)
|
||||
|
||||
|
||||
class CanvasCourse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
course_code: str
|
||||
created_at: str
|
||||
workflow_state: str
|
||||
name: str | None = None
|
||||
course_code: str | None = None
|
||||
created_at: str | None = None
|
||||
workflow_state: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any]) -> "CanvasCourse":
|
||||
return cls(
|
||||
id=payload["id"],
|
||||
name=payload.get("name"),
|
||||
course_code=payload.get("course_code"),
|
||||
created_at=payload.get("created_at"),
|
||||
workflow_state=payload.get("workflow_state"),
|
||||
)
|
||||
|
||||
|
||||
class CanvasPage(BaseModel):
|
||||
@@ -19,10 +84,22 @@ class CanvasPage(BaseModel):
|
||||
url: str
|
||||
title: str
|
||||
body: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
course_id: int
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasPage":
|
||||
return cls(
|
||||
page_id=payload["page_id"],
|
||||
url=payload["url"],
|
||||
title=payload["title"],
|
||||
body=payload.get("body"),
|
||||
created_at=payload.get("created_at"),
|
||||
updated_at=payload.get("updated_at"),
|
||||
course_id=course_id,
|
||||
)
|
||||
|
||||
|
||||
class CanvasAssignment(BaseModel):
|
||||
id: int
|
||||
@@ -30,10 +107,23 @@ class CanvasAssignment(BaseModel):
|
||||
description: str | None = None
|
||||
html_url: str
|
||||
course_id: int
|
||||
created_at: str
|
||||
updated_at: str
|
||||
created_at: str | None = None
|
||||
updated_at: str | None = None
|
||||
due_at: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAssignment":
|
||||
return cls(
|
||||
id=payload["id"],
|
||||
name=payload["name"],
|
||||
description=payload.get("description"),
|
||||
html_url=payload["html_url"],
|
||||
course_id=course_id,
|
||||
created_at=payload.get("created_at"),
|
||||
updated_at=payload.get("updated_at"),
|
||||
due_at=payload.get("due_at"),
|
||||
)
|
||||
|
||||
|
||||
class CanvasAnnouncement(BaseModel):
|
||||
id: int
|
||||
@@ -43,6 +133,17 @@ class CanvasAnnouncement(BaseModel):
|
||||
posted_at: str | None = None
|
||||
course_id: int
|
||||
|
||||
@classmethod
|
||||
def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAnnouncement":
|
||||
return cls(
|
||||
id=payload["id"],
|
||||
title=payload["title"],
|
||||
message=payload.get("message"),
|
||||
html_url=payload["html_url"],
|
||||
posted_at=payload.get("posted_at"),
|
||||
course_id=course_id,
|
||||
)
|
||||
|
||||
|
||||
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
|
||||
|
||||
@@ -72,3 +173,286 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
|
||||
self.current_course_index += 1
|
||||
self.stage = "pages"
|
||||
self.next_url = None
|
||||
|
||||
|
||||
class CanvasConnector(
|
||||
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
canvas_base_url: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.canvas_base_url = canvas_base_url.rstrip("/").removesuffix("/api/v1")
|
||||
self.batch_size = batch_size
|
||||
self._canvas_client: CanvasApiClient | None = None
|
||||
self._course_permissions_cache: dict[int, ExternalAccess | None] = {}
|
||||
|
||||
@property
|
||||
def canvas_client(self) -> CanvasApiClient:
|
||||
if self._canvas_client is None:
|
||||
raise ConnectorMissingCredentialError("Canvas")
|
||||
return self._canvas_client
|
||||
|
||||
def _get_course_permissions(self, course_id: int) -> ExternalAccess | None:
|
||||
"""Get course permissions with caching."""
|
||||
if course_id not in self._course_permissions_cache:
|
||||
self._course_permissions_cache[course_id] = get_course_permissions(
|
||||
canvas_client=self.canvas_client,
|
||||
course_id=course_id,
|
||||
)
|
||||
return self._course_permissions_cache[course_id]
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_courses(self) -> list[CanvasCourse]:
|
||||
"""Fetch all courses accessible to the authenticated user."""
|
||||
logger.debug("Fetching Canvas courses")
|
||||
|
||||
courses: list[CanvasCourse] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
"courses", params={"per_page": "100", "state[]": "available"}
|
||||
):
|
||||
courses.extend(CanvasCourse.from_api(c) for c in page)
|
||||
return courses
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_pages(self, course_id: int) -> list[CanvasPage]:
|
||||
"""Fetch all pages for a given course."""
|
||||
logger.debug(f"Fetching pages for course {course_id}")
|
||||
|
||||
pages: list[CanvasPage] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
f"courses/{course_id}/pages",
|
||||
params={"per_page": "100", "include[]": "body", "published": "true"},
|
||||
):
|
||||
pages.extend(CanvasPage.from_api(p, course_id=course_id) for p in page)
|
||||
return pages
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_assignments(self, course_id: int) -> list[CanvasAssignment]:
|
||||
"""Fetch all assignments for a given course."""
|
||||
logger.debug(f"Fetching assignments for course {course_id}")
|
||||
|
||||
assignments: list[CanvasAssignment] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
f"courses/{course_id}/assignments",
|
||||
params={"per_page": "100", "published": "true"},
|
||||
):
|
||||
assignments.extend(
|
||||
CanvasAssignment.from_api(a, course_id=course_id) for a in page
|
||||
)
|
||||
return assignments
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_announcements(self, course_id: int) -> list[CanvasAnnouncement]:
|
||||
"""Fetch all announcements for a given course."""
|
||||
logger.debug(f"Fetching announcements for course {course_id}")
|
||||
|
||||
announcements: list[CanvasAnnouncement] = []
|
||||
for page in self.canvas_client.paginate(
|
||||
"announcements",
|
||||
params={
|
||||
"per_page": "100",
|
||||
"context_codes[]": f"course_{course_id}",
|
||||
"active_only": "true",
|
||||
},
|
||||
):
|
||||
announcements.extend(
|
||||
CanvasAnnouncement.from_api(a, course_id=course_id) for a in page
|
||||
)
|
||||
return announcements
|
||||
|
||||
def _build_document(
|
||||
self,
|
||||
doc_id: str,
|
||||
link: str,
|
||||
text: str,
|
||||
semantic_identifier: str,
|
||||
doc_updated_at: datetime | None,
|
||||
course_id: int,
|
||||
doc_type: str,
|
||||
) -> Document:
|
||||
"""Build a Document with standard Canvas fields."""
|
||||
return Document(
|
||||
id=doc_id,
|
||||
sections=cast(
|
||||
list[TextSection | ImageSection],
|
||||
[TextSection(link=link, text=text)],
|
||||
),
|
||||
source=DocumentSource.CANVAS,
|
||||
semantic_identifier=semantic_identifier,
|
||||
doc_updated_at=doc_updated_at,
|
||||
metadata={"course_id": str(course_id), "type": doc_type},
|
||||
)
|
||||
|
||||
def _convert_page_to_document(self, page: CanvasPage) -> Document:
|
||||
"""Convert a Canvas page to a Document."""
|
||||
link = f"{self.canvas_base_url}/courses/{page.course_id}/pages/{page.url}"
|
||||
|
||||
text_parts = [page.title]
|
||||
body_text = parse_html_page_basic(page.body) if page.body else ""
|
||||
if body_text:
|
||||
text_parts.append(body_text)
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
if page.updated_at
|
||||
else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
|
||||
link=link,
|
||||
text="\n\n".join(text_parts),
|
||||
semantic_identifier=page.title or f"Page {page.page_id}",
|
||||
doc_updated_at=doc_updated_at,
|
||||
course_id=page.course_id,
|
||||
doc_type="page",
|
||||
)
|
||||
return document
|
||||
|
||||
def _convert_assignment_to_document(self, assignment: CanvasAssignment) -> Document:
|
||||
"""Convert a Canvas assignment to a Document."""
|
||||
text_parts = [assignment.name]
|
||||
desc_text = (
|
||||
parse_html_page_basic(assignment.description)
|
||||
if assignment.description
|
||||
else ""
|
||||
)
|
||||
if desc_text:
|
||||
text_parts.append(desc_text)
|
||||
if assignment.due_at:
|
||||
due_dt = datetime.fromisoformat(
|
||||
assignment.due_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(
|
||||
assignment.updated_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
if assignment.updated_at
|
||||
else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
doc_id=f"canvas-assignment-{assignment.course_id}-{assignment.id}",
|
||||
link=assignment.html_url,
|
||||
text="\n\n".join(text_parts),
|
||||
semantic_identifier=assignment.name or f"Assignment {assignment.id}",
|
||||
doc_updated_at=doc_updated_at,
|
||||
course_id=assignment.course_id,
|
||||
doc_type="assignment",
|
||||
)
|
||||
return document
|
||||
|
||||
def _convert_announcement_to_document(
|
||||
self, announcement: CanvasAnnouncement
|
||||
) -> Document:
|
||||
"""Convert a Canvas announcement to a Document."""
|
||||
text_parts = [announcement.title]
|
||||
msg_text = (
|
||||
parse_html_page_basic(announcement.message) if announcement.message else ""
|
||||
)
|
||||
if msg_text:
|
||||
text_parts.append(msg_text)
|
||||
|
||||
doc_updated_at = (
|
||||
datetime.fromisoformat(
|
||||
announcement.posted_at.replace("Z", "+00:00")
|
||||
).astimezone(timezone.utc)
|
||||
if announcement.posted_at
|
||||
else None
|
||||
)
|
||||
|
||||
document = self._build_document(
|
||||
doc_id=f"canvas-announcement-{announcement.course_id}-{announcement.id}",
|
||||
link=announcement.html_url,
|
||||
text="\n\n".join(text_parts),
|
||||
semantic_identifier=announcement.title or f"Announcement {announcement.id}",
|
||||
doc_updated_at=doc_updated_at,
|
||||
course_id=announcement.course_id,
|
||||
doc_type="announcement",
|
||||
)
|
||||
return document
|
||||
|
||||
@override
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load and validate Canvas credentials."""
|
||||
access_token = credentials.get("canvas_access_token")
|
||||
if not access_token:
|
||||
raise ConnectorMissingCredentialError("Canvas")
|
||||
|
||||
try:
|
||||
client = CanvasApiClient(
|
||||
bearer_token=access_token,
|
||||
canvas_base_url=self.canvas_base_url,
|
||||
)
|
||||
client.get("courses", params={"per_page": "1"})
|
||||
except ValueError as e:
|
||||
raise ConnectorValidationError(f"Invalid Canvas base URL: {e}")
|
||||
except OnyxError as e:
|
||||
_handle_canvas_api_error(e)
|
||||
|
||||
self._canvas_client = client
|
||||
return None
|
||||
|
||||
@override
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Canvas connector settings by testing API access."""
|
||||
try:
|
||||
self.canvas_client.get("courses", params={"per_page": "1"})
|
||||
logger.info("Canvas connector settings validated successfully")
|
||||
except OnyxError as e:
|
||||
_handle_canvas_api_error(e)
|
||||
except ConnectorMissingCredentialError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected error during Canvas settings validation: {exc}"
|
||||
)
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CanvasConnectorCheckpoint,
|
||||
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> CanvasConnectorCheckpoint:
|
||||
# TODO(benwu408): implemented in PR3 (checkpoint)
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
# TODO(benwu408): implemented in PR4 (perm sync)
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -11,11 +11,13 @@ from discord import Client
|
||||
from discord.channel import TextChannel
|
||||
from discord.channel import Thread
|
||||
from discord.enums import MessageType
|
||||
from discord.errors import LoginFailure
|
||||
from discord.flags import Intents
|
||||
from discord.message import Message as DiscordMessage
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import CredentialInvalidError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -209,8 +211,19 @@ def _manage_async_retrieval(
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents) as discord_client:
|
||||
asyncio.create_task(discord_client.start(token))
|
||||
await discord_client.wait_until_ready()
|
||||
start_task = asyncio.create_task(discord_client.start(token))
|
||||
ready_task = asyncio.create_task(discord_client.wait_until_ready())
|
||||
|
||||
done, _ = await asyncio.wait(
|
||||
{start_task, ready_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
# start() runs indefinitely once connected, so it only lands
|
||||
# in `done` when login/connection failed — propagate the error.
|
||||
if start_task in done:
|
||||
ready_task.cancel()
|
||||
start_task.result()
|
||||
|
||||
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
|
||||
discord_client=discord_client,
|
||||
@@ -276,6 +289,19 @@ class DiscordConnector(PollConnector, LoadConnector):
|
||||
self._discord_bot_token = credentials["discord_bot_token"]
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
client = Client(intents=Intents.default())
|
||||
try:
|
||||
loop.run_until_complete(client.login(self.discord_bot_token))
|
||||
except LoginFailure as e:
|
||||
raise CredentialInvalidError(f"Invalid Discord bot token: {e}")
|
||||
finally:
|
||||
loop.run_until_complete(client.close())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def _manage_doc_batching(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
|
||||
@@ -72,6 +72,10 @@ CONNECTOR_CLASS_MAP = {
|
||||
module_path="onyx.connectors.coda.connector",
|
||||
class_name="CodaConnector",
|
||||
),
|
||||
DocumentSource.CANVAS: ConnectorMapping(
|
||||
module_path="onyx.connectors.canvas.connector",
|
||||
class_name="CanvasConnector",
|
||||
),
|
||||
DocumentSource.NOTION: ConnectorMapping(
|
||||
module_path="onyx.connectors.notion.connector",
|
||||
class_name="NotionConnector",
|
||||
|
||||
@@ -8,7 +8,6 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
@@ -132,32 +131,47 @@ def get_chat_sessions_by_user(
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(ChatSession.project_id == project_id)
|
||||
elif only_non_project_chats:
|
||||
stmt = stmt.where(ChatSession.project_id.is_(None))
|
||||
|
||||
if not include_failed_chats:
|
||||
non_system_message_exists_subq = (
|
||||
exists()
|
||||
.where(ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.correlate(ChatSession)
|
||||
)
|
||||
|
||||
# Leeway for newly created chats that don't have messages yet
|
||||
time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
recently_created = ChatSession.time_created >= time
|
||||
|
||||
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
|
||||
# When filtering out failed chats, we apply the limit in Python after
|
||||
# filtering rather than in SQL, since the post-filter may remove rows.
|
||||
if limit and include_failed_chats:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = result.scalars().all()
|
||||
chat_sessions = list(result.scalars().all())
|
||||
|
||||
return list(chat_sessions)
|
||||
if not include_failed_chats and chat_sessions:
|
||||
# Filter out "failed" sessions (those with only SYSTEM messages)
|
||||
# using a separate efficient query instead of a correlated EXISTS
|
||||
# subquery, which causes full sequential scans of chat_message.
|
||||
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
|
||||
|
||||
if session_ids:
|
||||
valid_session_ids_stmt = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.where(ChatMessage.chat_session_id.in_(session_ids))
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.distinct()
|
||||
)
|
||||
valid_session_ids = set(
|
||||
db_session.execute(valid_session_ids_stmt).scalars().all()
|
||||
)
|
||||
|
||||
chat_sessions = [
|
||||
cs
|
||||
for cs in chat_sessions
|
||||
if cs.time_created >= leeway or cs.id in valid_session_ids
|
||||
]
|
||||
|
||||
if limit:
|
||||
chat_sessions = chat_sessions[:limit]
|
||||
|
||||
return chat_sessions
|
||||
|
||||
|
||||
def delete_orphaned_search_docs(db_session: Session) -> None:
|
||||
@@ -617,92 +631,6 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
|
||||
|
||||
Also advances ``latest_child_message_id`` so the preferred response becomes
|
||||
the active branch for any subsequent messages in the conversation.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
|
||||
preferred response is being set.
|
||||
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
|
||||
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
|
||||
|
||||
Raises:
|
||||
ValueError: If either message is not found, if ``user_message_id`` does not
|
||||
refer to a USER message, or if the assistant message is not a direct child
|
||||
of the user message.
|
||||
"""
|
||||
user_msg = db_session.get(ChatMessage, user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child "
|
||||
f"of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
user_msg.latest_child_message_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -925,8 +853,6 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -932,7 +932,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
def search_for_document_ids(
|
||||
self,
|
||||
body: dict[str, Any],
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
|
||||
) -> list[str]:
|
||||
"""Searches the index and returns only document chunk IDs.
|
||||
|
||||
|
||||
@@ -60,8 +60,7 @@ class OpenSearchSearchType(str, Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
RANDOM = "random"
|
||||
ID_RETRIEVAL = "id_retrieval"
|
||||
DOCUMENT_IDS = "document_ids"
|
||||
DOC_ID_RETRIEVAL = "doc_id_retrieval"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
|
||||
@@ -928,7 +928,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.ID_RETRIEVAL,
|
||||
search_type=OpenSearchSearchType.DOC_ID_RETRIEVAL,
|
||||
)
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -7,10 +6,7 @@ from shared_configs.configs import MULTI_TENANT
|
||||
def require_hook_enabled() -> None:
|
||||
"""FastAPI dependency that gates all hook management endpoints.
|
||||
|
||||
Hooks are only available in single-tenant / self-hosted deployments with
|
||||
HOOK_ENABLED=true explicitly set. Two layers of protection:
|
||||
1. MULTI_TENANT check — rejects even if HOOK_ENABLED is accidentally set true
|
||||
2. HOOK_ENABLED flag — explicit opt-in by the operator
|
||||
Hooks are only available in single-tenant / self-hosted EE deployments.
|
||||
|
||||
Use as: Depends(require_hook_enabled)
|
||||
"""
|
||||
@@ -19,8 +15,3 @@ def require_hook_enabled() -> None:
|
||||
OnyxErrorCode.SINGLE_TENANT_ONLY,
|
||||
"Hooks are not available in multi-tenant deployments",
|
||||
)
|
||||
if not HOOK_ENABLED:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.ENV_VAR_GATED,
|
||||
"Hooks are not enabled. Set HOOK_ENABLED=true to enable.",
|
||||
)
|
||||
|
||||
@@ -1,79 +1,22 @@
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
"""CE hook executor.
|
||||
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
HookSkipped and HookSoftFailed are real classes kept here because
|
||||
process_message.py (CE code) uses isinstance checks against them.
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is a validated Pydantic model instance (response_type)
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
execute_hook is the public entry point. It dispatches to _execute_hook_impl
|
||||
via fetch_versioned_implementation so that:
|
||||
- CE: onyx.hooks.executor._execute_hook_impl → no-op, returns HookSkipped()
|
||||
- EE: ee.onyx.hooks.executor._execute_hook_impl → real HTTP call
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
class HookSkipped:
|
||||
@@ -87,277 +30,15 @@ class HookSoftFailed:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if not HOOKS_AVAILABLE:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
def _execute_hook_impl(
|
||||
*,
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _execute_hook_inner(
|
||||
hook: Hook,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSoftFailed:
|
||||
"""Make the HTTP call, validate the response, and return a typed model.
|
||||
|
||||
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
|
||||
"""
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(
|
||||
timeout=timeout, follow_redirects=False
|
||||
) as client: # SSRF guard: never follow redirects
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
|
||||
# Validate the response payload against response_type.
|
||||
# A validation failure downgrades the outcome to a failure so it is logged,
|
||||
# is_reachable is left unchanged (server responded — just a bad payload),
|
||||
# and fail_strategy is respected below.
|
||||
validated_model: T | None = None
|
||||
if outcome.is_success and outcome.response_payload is not None:
|
||||
try:
|
||||
validated_model = response_type.model_validate(outcome.response_payload)
|
||||
except ValidationError as e:
|
||||
msg = (
|
||||
f"Hook response failed validation against {response_type.__name__}: {e}"
|
||||
)
|
||||
outcome = _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=outcome.status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
|
||||
if validated_model is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"validated_model is None for successful hook call (hook_id={hook_id})",
|
||||
)
|
||||
return validated_model
|
||||
db_session: Session, # noqa: ARG001
|
||||
hook_point: HookPoint, # noqa: ARG001
|
||||
payload: dict[str, Any], # noqa: ARG001
|
||||
response_type: type[T], # noqa: ARG001
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""CE no-op — hooks are not available without EE."""
|
||||
return HookSkipped()
|
||||
|
||||
|
||||
def execute_hook(
|
||||
@@ -367,25 +48,15 @@ def execute_hook(
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point synchronously.
|
||||
"""Execute the hook for the given hook point.
|
||||
|
||||
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
|
||||
hook failed with SOFT fail strategy, or a validated response model on success.
|
||||
Raises OnyxError on HARD failure or if the hook is misconfigured.
|
||||
Dispatches to the versioned implementation so EE gets the real executor
|
||||
and CE gets the no-op stub, without any changes at the call site.
|
||||
"""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
fail_strategy = hook.fail_strategy
|
||||
hook_id = hook.id
|
||||
|
||||
try:
|
||||
return _execute_hook_inner(hook, payload, response_type)
|
||||
except Exception:
|
||||
if fail_strategy == HookFailStrategy.SOFT:
|
||||
logger.exception(
|
||||
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
raise
|
||||
impl = fetch_versioned_implementation("onyx.hooks.executor", "_execute_hook_impl")
|
||||
return impl(
|
||||
db_session=db_session,
|
||||
hook_point=hook_point,
|
||||
payload=payload,
|
||||
response_type=response_type,
|
||||
)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT
|
||||
@@ -8,24 +8,6 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMOverride(BaseModel):
|
||||
"""Per-request LLM settings that override persona defaults.
|
||||
|
||||
All fields are optional — only the fields that differ from the persona's
|
||||
configured LLM need to be supplied. Used both over the wire (API requests)
|
||||
and for multi-model comparison, where one override is supplied per model.
|
||||
|
||||
Attributes:
|
||||
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
|
||||
When ``None``, the persona's default provider is used.
|
||||
model_version: Specific model version string (e.g. ``"gpt-4o"``).
|
||||
When ``None``, the persona's default model is used.
|
||||
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
|
||||
persona's default temperature is used.
|
||||
display_name: Human-readable label shown in the UI for this model,
|
||||
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
|
||||
when not set.
|
||||
"""
|
||||
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
@@ -77,7 +77,6 @@ from onyx.server.features.default_assistant.api import (
|
||||
)
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.hierarchy.api import router as hierarchy_router
|
||||
from onyx.server.features.hooks.api import router as hook_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
admin_router as admin_input_prompt_router,
|
||||
)
|
||||
@@ -439,6 +438,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[StarletteIntegration(), FastApiIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
release=__version__,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
@@ -454,7 +454,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
|
||||
register_onyx_exception_handlers(application)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.chat.chat_utils import extract_headers
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
@@ -47,7 +46,6 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -62,8 +60,6 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -85,7 +81,6 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -575,46 +570,6 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -705,30 +660,6 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
try:
|
||||
# Ownership check: get_chat_message raises ValueError if the message
|
||||
# doesn't belong to this user, preventing cross-user mutation.
|
||||
get_chat_message(
|
||||
chat_message_id=request_body.user_message_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -2,25 +2,11 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class Placement(BaseModel):
|
||||
"""Coordinates that identify where a streaming packet belongs in the UI.
|
||||
|
||||
The frontend uses these fields to route each packet to the correct turn,
|
||||
tool tab, agent sub-turn, and (in multi-model mode) response column.
|
||||
|
||||
Attributes:
|
||||
turn_index: Monotonically increasing index of the iterative reasoning block
|
||||
(e.g. tool call round) within this chat message. Lower values happened first.
|
||||
tab_index: Disambiguates parallel tool calls within the same turn so each
|
||||
tool's output can be displayed in its own tab.
|
||||
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
|
||||
top-level packets; an integer for tool-within-tool output.
|
||||
model_index: Which model this packet belongs to. ``0`` for single-model
|
||||
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
|
||||
for pre-LLM setup packets (e.g. message ID info) that are yielded
|
||||
before any Emitter runs.
|
||||
"""
|
||||
|
||||
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
|
||||
turn_index: int
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -21,7 +21,6 @@ from onyx.db.notification import get_notifications
|
||||
from onyx.db.notification import update_notification_last_shown
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.features.build.utils import is_onyx_craft_enabled
|
||||
@@ -38,6 +37,7 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -98,7 +98,7 @@ def fetch_settings(
|
||||
needs_reindexing=needs_reindexing,
|
||||
onyx_craft_enabled=onyx_craft_enabled_for_user,
|
||||
vector_db_enabled=not DISABLE_VECTOR_DB,
|
||||
hooks_enabled=HOOKS_AVAILABLE,
|
||||
hooks_enabled=not MULTI_TENANT,
|
||||
version=onyx_version,
|
||||
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
default_user_file_max_upload_size_mb=min(
|
||||
|
||||
@@ -116,7 +116,7 @@ class UserSettings(Settings):
|
||||
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
|
||||
# document sets are unavailable.
|
||||
vector_db_enabled: bool = True
|
||||
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
# True when hooks are available: single-tenant EE deployments only.
|
||||
hooks_enabled: bool = False
|
||||
# Application version, read from the ONYX_VERSION env var at startup.
|
||||
version: str | None = None
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -709,6 +708,7 @@ def run_research_agent_calls(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from queue import Queue
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
@@ -744,8 +744,8 @@ if __name__ == "__main__":
|
||||
if user is None:
|
||||
raise ValueError("No users found in database. Please create a user first.")
|
||||
|
||||
emitter_queue: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(merged_queue=emitter_queue)
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
state_container = ChatStateContainer()
|
||||
|
||||
tool_dict = construct_tools(
|
||||
@@ -792,4 +792,4 @@ if __name__ == "__main__":
|
||||
print(result.intermediate_report)
|
||||
print("=" * 80)
|
||||
print(f"Citations: {result.citation_mapping}")
|
||||
print(f"Total packets emitted: {emitter_queue.qsize()}")
|
||||
print(f"Total packets emitted: {bus.qsize()}")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import csv
|
||||
import json
|
||||
import queue
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
@@ -12,6 +11,7 @@ import requests
|
||||
from requests import JSONDecodeError
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
|
||||
# Use a discard emitter if none provided (packets go nowhere)
|
||||
# Use default emitter if none provided
|
||||
if emitter is None:
|
||||
emitter = Emitter(merged_queue=queue.Queue())
|
||||
emitter = get_default_emitter()
|
||||
|
||||
return [
|
||||
CustomTool(
|
||||
@@ -367,7 +367,7 @@ if __name__ == "__main__":
|
||||
tools = build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=openapi_schema,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
dynamic_schema_info=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -27,6 +28,9 @@ INTERNAL_SEARCH_TOOL_NAME = "internal_search"
|
||||
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
|
||||
MAX_REQUEST_ATTEMPTS = 5
|
||||
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
|
||||
QUESTION_TIMEOUT_SECONDS = 300
|
||||
QUESTION_RETRY_PAUSE_SECONDS = 30
|
||||
MAX_QUESTION_ATTEMPTS = 3
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -109,6 +113,27 @@ def normalize_api_base(api_base: str) -> str:
|
||||
return f"{normalized}/api"
|
||||
|
||||
|
||||
def load_completed_question_ids(output_file: Path) -> set[str]:
|
||||
if not output_file.exists():
|
||||
return set()
|
||||
|
||||
completed_ids: set[str] = set()
|
||||
with output_file.open("r", encoding="utf-8") as file:
|
||||
for line in file:
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
try:
|
||||
record = json.loads(stripped)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
question_id = record.get("question_id")
|
||||
if isinstance(question_id, str) and question_id:
|
||||
completed_ids.add(question_id)
|
||||
|
||||
return completed_ids
|
||||
|
||||
|
||||
def load_questions(questions_file: Path) -> list[QuestionRecord]:
|
||||
if not questions_file.exists():
|
||||
raise FileNotFoundError(f"Questions file not found: {questions_file}")
|
||||
@@ -348,6 +373,7 @@ async def generate_answers(
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
parallelism: int,
|
||||
skipped: int,
|
||||
) -> None:
|
||||
if parallelism < 1:
|
||||
raise ValueError("`--parallelism` must be at least 1.")
|
||||
@@ -382,58 +408,178 @@ async def generate_answers(
|
||||
write_lock = asyncio.Lock()
|
||||
completed = 0
|
||||
successful = 0
|
||||
stuck_count = 0
|
||||
failed_questions: list[FailedQuestionRecord] = []
|
||||
total = len(questions)
|
||||
remaining_count = len(questions)
|
||||
overall_total = remaining_count + skipped
|
||||
question_durations: list[float] = []
|
||||
run_start_time = time.monotonic()
|
||||
|
||||
def print_progress() -> None:
|
||||
avg_time = (
|
||||
sum(question_durations) / len(question_durations)
|
||||
if question_durations
|
||||
else 0.0
|
||||
)
|
||||
elapsed = time.monotonic() - run_start_time
|
||||
eta = avg_time * (remaining_count - completed) / max(parallelism, 1)
|
||||
|
||||
done = skipped + completed
|
||||
bar_width = 30
|
||||
filled = (
|
||||
int(bar_width * done / overall_total)
|
||||
if overall_total
|
||||
else bar_width
|
||||
)
|
||||
bar = "█" * filled + "░" * (bar_width - filled)
|
||||
pct = (done / overall_total * 100) if overall_total else 100.0
|
||||
|
||||
parts = (
|
||||
f"\r{bar} {pct:5.1f}% "
|
||||
f"[{done}/{overall_total}] "
|
||||
f"avg {avg_time:.1f}s/q "
|
||||
f"elapsed {elapsed:.0f}s "
|
||||
f"ETA {eta:.0f}s "
|
||||
f"(ok:{successful} fail:{len(failed_questions)}"
|
||||
)
|
||||
if stuck_count:
|
||||
parts += f" stuck:{stuck_count}"
|
||||
if skipped:
|
||||
parts += f" skip:{skipped}"
|
||||
parts += ")"
|
||||
|
||||
sys.stderr.write(parts)
|
||||
sys.stderr.flush()
|
||||
|
||||
print_progress()
|
||||
|
||||
async def process_question(question_record: QuestionRecord) -> None:
|
||||
nonlocal completed
|
||||
nonlocal successful
|
||||
nonlocal stuck_count
|
||||
|
||||
try:
|
||||
async with semaphore:
|
||||
result = await submit_question(
|
||||
session=session,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
internal_search_tool_id=internal_search_tool_id,
|
||||
question_record=question_record,
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(1, MAX_QUESTION_ATTEMPTS + 1):
|
||||
q_start = time.monotonic()
|
||||
try:
|
||||
async with semaphore:
|
||||
result = await asyncio.wait_for(
|
||||
submit_question(
|
||||
session=session,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
internal_search_tool_id=internal_search_tool_id,
|
||||
question_record=question_record,
|
||||
),
|
||||
timeout=QUESTION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
async with progress_lock:
|
||||
stuck_count += 1
|
||||
logger.warning(
|
||||
"Question %s timed out after %ss (attempt %s/%s, "
|
||||
"total stuck: %s) — retrying in %ss",
|
||||
question_record.question_id,
|
||||
QUESTION_TIMEOUT_SECONDS,
|
||||
attempt,
|
||||
MAX_QUESTION_ATTEMPTS,
|
||||
stuck_count,
|
||||
QUESTION_RETRY_PAUSE_SECONDS,
|
||||
)
|
||||
print_progress()
|
||||
last_error = TimeoutError(
|
||||
f"Timed out after {QUESTION_TIMEOUT_SECONDS}s "
|
||||
f"on attempt {attempt}/{MAX_QUESTION_ATTEMPTS}"
|
||||
)
|
||||
except Exception as exc:
|
||||
await asyncio.sleep(QUESTION_RETRY_PAUSE_SECONDS)
|
||||
continue
|
||||
except Exception as exc:
|
||||
duration = time.monotonic() - q_start
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
question_durations.append(duration)
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
"Failed question %s (%s/%s)",
|
||||
question_record.question_id,
|
||||
completed,
|
||||
remaining_count,
|
||||
)
|
||||
print_progress()
|
||||
return
|
||||
|
||||
duration = time.monotonic() - q_start
|
||||
|
||||
async with write_lock:
|
||||
file.write(json.dumps(asdict(result), ensure_ascii=False))
|
||||
file.write("\n")
|
||||
file.flush()
|
||||
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
"Failed question %s (%s/%s)",
|
||||
question_record.question_id,
|
||||
completed,
|
||||
total,
|
||||
)
|
||||
successful += 1
|
||||
question_durations.append(duration)
|
||||
print_progress()
|
||||
return
|
||||
|
||||
async with write_lock:
|
||||
file.write(json.dumps(asdict(result), ensure_ascii=False))
|
||||
file.write("\n")
|
||||
file.flush()
|
||||
|
||||
# All attempts exhausted due to timeouts
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
successful += 1
|
||||
logger.info("Processed %s/%s questions", completed, total)
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(last_error),
|
||||
)
|
||||
)
|
||||
logger.error(
|
||||
"Question %s failed after %s timeout attempts (%s/%s)",
|
||||
question_record.question_id,
|
||||
MAX_QUESTION_ATTEMPTS,
|
||||
completed,
|
||||
remaining_count,
|
||||
)
|
||||
print_progress()
|
||||
|
||||
await asyncio.gather(
|
||||
*(process_question(question_record) for question_record in questions)
|
||||
)
|
||||
|
||||
# Final newline after progress bar
|
||||
sys.stderr.write("\n")
|
||||
sys.stderr.flush()
|
||||
|
||||
total_elapsed = time.monotonic() - run_start_time
|
||||
avg_time = (
|
||||
sum(question_durations) / len(question_durations)
|
||||
if question_durations
|
||||
else 0.0
|
||||
)
|
||||
stuck_suffix = f", {stuck_count} stuck timeouts" if stuck_count else ""
|
||||
resume_suffix = (
|
||||
f" — {skipped} previously completed, "
|
||||
f"{skipped + successful}/{overall_total} overall"
|
||||
if skipped
|
||||
else ""
|
||||
)
|
||||
logger.info(
|
||||
"Done: %s/%s successful in %.1fs (avg %.1fs/question%s)%s",
|
||||
successful,
|
||||
remaining_count,
|
||||
total_elapsed,
|
||||
avg_time,
|
||||
stuck_suffix,
|
||||
resume_suffix,
|
||||
)
|
||||
|
||||
if failed_questions:
|
||||
logger.warning(
|
||||
"Completed with %s failed questions and %s successful questions.",
|
||||
"%s questions failed:",
|
||||
len(failed_questions),
|
||||
successful,
|
||||
)
|
||||
for failed_question in failed_questions:
|
||||
logger.warning(
|
||||
@@ -453,7 +599,30 @@ def main() -> None:
|
||||
raise ValueError("`--max-questions` must be at least 1 when provided.")
|
||||
questions = questions[: args.max_questions]
|
||||
|
||||
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
|
||||
completed_ids = load_completed_question_ids(args.output_file)
|
||||
logger.info(
|
||||
"Found %s already-answered question IDs in %s",
|
||||
len(completed_ids),
|
||||
args.output_file,
|
||||
)
|
||||
total_before_filter = len(questions)
|
||||
questions = [q for q in questions if q.question_id not in completed_ids]
|
||||
skipped = total_before_filter - len(questions)
|
||||
|
||||
if skipped:
|
||||
logger.info(
|
||||
"Resuming: %s/%s already answered, %s remaining",
|
||||
skipped,
|
||||
total_before_filter,
|
||||
len(questions),
|
||||
)
|
||||
else:
|
||||
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
|
||||
|
||||
if not questions:
|
||||
logger.info("All questions already answered. Nothing to do.")
|
||||
return
|
||||
|
||||
logger.info("Writing answers to %s", args.output_file)
|
||||
|
||||
asyncio.run(
|
||||
@@ -463,6 +632,7 @@ def main() -> None:
|
||||
api_base=api_base,
|
||||
api_key=args.api_key,
|
||||
parallelism=args.parallelism,
|
||||
skipped=skipped,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -27,13 +27,11 @@ def create_placement(
|
||||
turn_index: int,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
model_index: int | None = 0,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
model_index=model_index,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ This test:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -21,7 +20,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -138,7 +137,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -201,7 +200,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -276,7 +275,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -351,7 +350,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -459,7 +458,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -542,7 +541,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
|
||||
@@ -8,7 +8,6 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
@@ -17,7 +16,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -175,7 +174,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -233,7 +232,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -285,7 +284,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -346,7 +345,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -417,7 +416,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -484,7 +483,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -537,7 +536,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
0
backend/tests/unit/ee/onyx/hooks/__init__.py
Normal file
0
backend/tests/unit/ee/onyx/hooks/__init__.py
Normal file
@@ -9,11 +9,11 @@ import httpx
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.hooks.executor import _execute_hook_impl as execute_hook
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.hooks.points.query_processing import QueryProcessingResponse
|
||||
@@ -118,28 +118,30 @@ def db_session() -> MagicMock:
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hooks_available,hook",
|
||||
"multi_tenant,hook",
|
||||
[
|
||||
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
|
||||
pytest.param(False, None, id="hooks_not_available"),
|
||||
pytest.param(True, None, id="hook_not_found"),
|
||||
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
|
||||
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
|
||||
# MULTI_TENANT=True exits before the DB lookup — hook is irrelevant.
|
||||
pytest.param(True, None, id="multi_tenant"),
|
||||
pytest.param(False, None, id="hook_not_found"),
|
||||
pytest.param(False, _make_hook(is_active=False), id="hook_inactive"),
|
||||
pytest.param(False, _make_hook(endpoint_url=None), id="no_endpoint_url"),
|
||||
],
|
||||
)
|
||||
def test_early_exit_returns_skipped_with_no_db_writes(
|
||||
db_session: MagicMock,
|
||||
hooks_available: bool,
|
||||
multi_tenant: bool,
|
||||
hook: MagicMock | None,
|
||||
) -> None:
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
|
||||
patch("ee.onyx.hooks.executor.MULTI_TENANT", multi_tenant),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch(
|
||||
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
|
||||
) as mock_log,
|
||||
):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
@@ -164,14 +166,16 @@ def test_success_returns_validated_model_and_sets_reachable(
|
||||
hook = _make_hook()
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch(
|
||||
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
|
||||
) as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, response=_make_response())
|
||||
@@ -195,14 +199,14 @@ def test_success_skips_reachable_write_when_already_true(db_session: MagicMock)
|
||||
hook = _make_hook(is_reachable=True)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, response=_make_response())
|
||||
@@ -224,14 +228,16 @@ def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch(
|
||||
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
|
||||
) as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
@@ -258,14 +264,16 @@ def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch(
|
||||
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
|
||||
) as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(
|
||||
@@ -384,14 +392,14 @@ def test_http_failure_paths(
|
||||
hook = _make_hook(fail_strategy=fail_strategy)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, side_effect=exception)
|
||||
@@ -443,14 +451,14 @@ def test_authorization_header(
|
||||
hook = _make_hook(api_key=api_key)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit"),
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("ee.onyx.hooks.executor.update_hook__no_commit"),
|
||||
patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
mock_client = _setup_client(mock_client_cls, response=_make_response())
|
||||
@@ -489,13 +497,13 @@ def test_persist_session_failure_is_swallowed(
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_session_with_current_tenant",
|
||||
"ee.onyx.hooks.executor.get_session_with_current_tenant",
|
||||
side_effect=RuntimeError("DB unavailable"),
|
||||
),
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
@@ -556,14 +564,16 @@ def test_response_validation_failure_respects_fail_strategy(
|
||||
hook = _make_hook(fail_strategy=fail_strategy)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
|
||||
patch(
|
||||
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
|
||||
) as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
# Response payload is missing required_field → ValidationError
|
||||
@@ -619,13 +629,13 @@ def test_unexpected_exception_in_inner_respects_fail_strategy(
|
||||
hook = _make_hook(fail_strategy=fail_strategy)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch(
|
||||
"onyx.hooks.executor._execute_hook_inner",
|
||||
"ee.onyx.hooks.executor._execute_hook_inner",
|
||||
side_effect=ValueError("unexpected bug"),
|
||||
),
|
||||
):
|
||||
@@ -658,17 +668,19 @@ def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> Non
|
||||
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
with (
|
||||
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
|
||||
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
|
||||
patch(
|
||||
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
|
||||
return_value=hook,
|
||||
),
|
||||
patch("onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.hooks.executor.update_hook__no_commit",
|
||||
"ee.onyx.hooks.executor.update_hook__no_commit",
|
||||
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
|
||||
),
|
||||
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
|
||||
patch(
|
||||
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
|
||||
) as mock_log,
|
||||
patch("httpx.Client") as mock_client_cls,
|
||||
):
|
||||
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))
|
||||
0
backend/tests/unit/ee/onyx/server/__init__.py
Normal file
0
backend/tests/unit/ee/onyx/server/__init__.py
Normal file
@@ -1,4 +1,4 @@
|
||||
"""Unit tests for onyx.server.features.hooks.api helpers.
|
||||
"""Unit tests for ee.onyx.server.features.hooks.api helpers.
|
||||
|
||||
Covers:
|
||||
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
|
||||
@@ -16,13 +16,13 @@ from unittest.mock import patch
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from ee.onyx.server.features.hooks.api import _check_ssrf_safety
|
||||
from ee.onyx.server.features.hooks.api import _raise_for_validation_failure
|
||||
from ee.onyx.server.features.hooks.api import _validate_endpoint
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.models import HookValidateResponse
|
||||
from onyx.hooks.models import HookValidateStatus
|
||||
from onyx.server.features.hooks.api import _check_ssrf_safety
|
||||
from onyx.server.features.hooks.api import _raise_for_validation_failure
|
||||
from onyx.server.features.hooks.api import _validate_endpoint
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -117,28 +117,28 @@ class TestCheckSsrfSafety:
|
||||
class TestValidateEndpoint:
|
||||
def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
|
||||
# Bypass SSRF check — tested separately in TestCheckSsrfSafety.
|
||||
with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
|
||||
with patch("ee.onyx.server.features.hooks.api._check_ssrf_safety"):
|
||||
return _validate_endpoint(
|
||||
endpoint_url=_URL,
|
||||
api_key=api_key,
|
||||
timeout_seconds=_TIMEOUT,
|
||||
)
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(200)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(500)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize("status_code", [401, 403])
|
||||
def test_401_403_returns_auth_failed(
|
||||
self, mock_client_cls: MagicMock, status_code: int
|
||||
@@ -150,21 +150,21 @@ class TestValidateEndpoint:
|
||||
assert result.status == HookValidateStatus.auth_failed
|
||||
assert str(status_code) in (result.error_message or "")
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
|
||||
_mock_response(422)
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectTimeout("timed out")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.timeout
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize(
|
||||
"exc",
|
||||
[
|
||||
@@ -179,7 +179,7 @@ class TestValidateEndpoint:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
|
||||
assert self._call().status == HookValidateStatus.timeout
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_error_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
@@ -189,7 +189,7 @@ class TestValidateEndpoint:
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_arbitrary_exception_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
@@ -198,7 +198,7 @@ class TestValidateEndpoint:
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = _mock_response(200)
|
||||
@@ -206,7 +206,7 @@ class TestValidateEndpoint:
|
||||
_, kwargs = mock_post.call_args
|
||||
assert kwargs["headers"]["Authorization"] == "Bearer mykey"
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_post = mock_client_cls.return_value.__enter__.return_value.post
|
||||
mock_post.return_value = _mock_response(200)
|
||||
@@ -1,173 +0,0 @@
|
||||
"""Unit tests for the Emitter class.
|
||||
|
||||
All tests use the streaming mode (merged_queue required). Emitter has a single
|
||||
code path — no standalone bus.
|
||||
"""
|
||||
|
||||
import queue
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _placement(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
|
||||
|
||||
def _packet(
|
||||
turn_index: int = 0,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> Packet:
|
||||
"""Build a minimal valid packet with an OverallStop payload."""
|
||||
return Packet(
|
||||
placement=_placement(turn_index, tab_index, sub_turn_index),
|
||||
obj=OverallStop(stop_reason="test"),
|
||||
)
|
||||
|
||||
|
||||
def _make_emitter(model_idx: int = 0) -> tuple["Emitter", "queue.Queue"]:
|
||||
"""Return (emitter, queue) wired together."""
|
||||
mq: queue.Queue = queue.Queue()
|
||||
return Emitter(merged_queue=mq, model_idx=model_idx), mq
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueRouting:
|
||||
def test_emit_lands_on_merged_queue(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet())
|
||||
assert not mq.empty()
|
||||
|
||||
def test_queue_item_is_tuple_of_key_and_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
item = mq.get_nowait()
|
||||
assert isinstance(item, tuple)
|
||||
assert len(item) == 2
|
||||
|
||||
def test_multiple_packets_delivered_fifo(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
p1 = _packet(turn_index=0)
|
||||
p2 = _packet(turn_index=1)
|
||||
emitter.emit(p1)
|
||||
emitter.emit(p2)
|
||||
_, t1 = mq.get_nowait()
|
||||
_, t2 = mq.get_nowait()
|
||||
assert t1.placement.turn_index == 0
|
||||
assert t2.placement.turn_index == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_index tagging
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterModelIndexTagging:
|
||||
def test_n1_default_model_idx_tags_model_index_zero(self) -> None:
|
||||
"""N=1: default model_idx=0, so packet gets model_index=0."""
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 0
|
||||
|
||||
def test_model_idx_one_tags_packet(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=1)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 1
|
||||
|
||||
def test_model_idx_two_tags_packet(self) -> None:
|
||||
"""Boundary: third model in a 3-model run."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
_key, tagged = mq.get_nowait()
|
||||
assert tagged.placement.model_index == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Queue key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterQueueKey:
|
||||
def test_key_equals_model_idx(self) -> None:
|
||||
"""Drain loop uses the key to route packets; it must match model_idx."""
|
||||
emitter, mq = _make_emitter(model_idx=2)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 2
|
||||
|
||||
def test_n1_key_is_zero(self) -> None:
|
||||
emitter, mq = _make_emitter(model_idx=0)
|
||||
emitter.emit(_packet())
|
||||
key, _ = mq.get_nowait()
|
||||
assert key == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Placement field preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmitterPlacementPreservation:
|
||||
def test_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(turn_index=5))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.turn_index == 5
|
||||
|
||||
def test_tab_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(tab_index=3))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.tab_index == 3
|
||||
|
||||
def test_sub_turn_index_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=2))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index == 2
|
||||
|
||||
def test_sub_turn_index_none_is_preserved(self) -> None:
|
||||
emitter, mq = _make_emitter()
|
||||
emitter.emit(_packet(sub_turn_index=None))
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.placement.sub_turn_index is None
|
||||
|
||||
def test_packet_obj_is_not_modified(self) -> None:
|
||||
"""The payload object must survive tagging untouched."""
|
||||
emitter, mq = _make_emitter()
|
||||
original_obj = OverallStop(stop_reason="sentinel")
|
||||
pkt = Packet(placement=_placement(), obj=original_obj)
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert tagged.obj is original_obj
|
||||
|
||||
def test_different_obj_types_are_handled(self) -> None:
|
||||
"""Any valid PacketObj type passes through correctly."""
|
||||
emitter, mq = _make_emitter()
|
||||
pkt = Packet(placement=_placement(), obj=ReasoningStart())
|
||||
emitter.emit(pkt)
|
||||
_, tagged = mq.get_nowait()
|
||||
assert isinstance(tagged.obj, ReasoningStart)
|
||||
@@ -1,676 +0,0 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in handle_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _restore_ee_version() -> Generator[None, None, None]:
|
||||
"""Reset EE global state after each test.
|
||||
|
||||
Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
|
||||
(via the celery client import chain). Without this fixture, the EE flag stays
|
||||
True for the rest of the session and breaks unrelated tests that mock Confluence
|
||||
or other connectors and assume EE is disabled.
|
||||
"""
|
||||
original = global_version._is_ee
|
||||
yield
|
||||
global_version._is_ee = original
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
|
||||
"""Return the first item yielded by handle_multi_model_stream."""
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = handle_multi_model_stream(req, user, db, overrides)
|
||||
return next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_yields_error(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_four_overrides_yields_error(self) -> None:
|
||||
"""4 overrides exceeds maximum — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_zero_overrides_yields_error(self) -> None:
|
||||
"""Empty override list yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_deep_research_yields_error(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
|
||||
req = _make_request(deep_research=True)
|
||||
result = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "not supported" in result.error
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override yields error, 2 overrides passes validation."""
|
||||
req = _make_request()
|
||||
# 1 override must yield a StreamingError
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(
|
||||
result, StreamingError
|
||||
), "1 override should yield StreamingError"
|
||||
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
|
||||
# missing session, that's OK — validation itself passed)
|
||||
try:
|
||||
result2 = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
if isinstance(result2, StreamingError) and "2-3" in result2.error:
|
||||
pytest.fail(
|
||||
f"2 overrides should pass validation, got StreamingError: {result2.error}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Any non-validation error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
assert user_msg.latest_child_message_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_models — drain loop behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_setup(n_models: int = 1) -> MagicMock:
|
||||
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
|
||||
setup = MagicMock()
|
||||
setup.llms = [MagicMock() for _ in range(n_models)]
|
||||
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
|
||||
setup.reserved_token_count = 100
|
||||
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
|
||||
# constructors inside _run_model — must be typed correctly for Pydantic.
|
||||
setup.new_msg_req.deep_research = False
|
||||
setup.new_msg_req.internal_search_filters = None
|
||||
setup.new_msg_req.allowed_tool_ids = None
|
||||
setup.new_msg_req.include_citations = True
|
||||
setup.search_params.project_id_filter = None
|
||||
setup.search_params.persona_id_filter = None
|
||||
setup.bypass_acl = False
|
||||
setup.slack_context = None
|
||||
setup.available_files.user_file_ids = []
|
||||
setup.available_files.chat_file_ids = []
|
||||
setup.forced_tool_id = None
|
||||
setup.simple_chat_history = []
|
||||
setup.chat_session.id = uuid4()
|
||||
setup.user_message.id = None
|
||||
setup.custom_tool_additional_headers = None
|
||||
setup.mcp_headers = None
|
||||
return setup
|
||||
|
||||
|
||||
_RUN_MODELS_PATCHES = [
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch("onyx.chat.process_message.get_llm_token_counter", return_value=lambda _: 0),
|
||||
]
|
||||
|
||||
|
||||
def _run_models_collect(setup: MagicMock) -> list:
|
||||
"""Drive _run_models to completion and return all yielded items."""
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
return list(_run_models(setup, MagicMock(), MagicMock()))
|
||||
|
||||
|
||||
class TestRunModels:
|
||||
"""Tests for the _run_models worker-thread drain loop.
|
||||
|
||||
All external dependencies (LLM, DB, tools) are patched out. Worker threads
|
||||
still run but return immediately since run_llm_loop is mocked.
|
||||
"""
|
||||
|
||||
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
|
||||
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
|
||||
|
||||
def emit_stop(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(stop_reason="complete"),
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert len(stops) == 1
|
||||
stop_obj = stops[0].obj
|
||||
assert isinstance(stop_obj, OverallStop)
|
||||
assert stop_obj.stop_reason == "complete"
|
||||
|
||||
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
|
||||
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 0
|
||||
|
||||
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
|
||||
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
# _model_idx is set by _run_model based on position in setup.llms
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=2))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 2
|
||||
indices = {p.placement.model_index for p in reasoning}
|
||||
assert indices == {0, 1}
|
||||
|
||||
def test_model_error_yields_streaming_error(self) -> None:
|
||||
"""An exception inside a worker thread is surfaced as a StreamingError."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("intentional test failure")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].error_code == "MODEL_ERROR"
|
||||
assert "intentional test failure" in errors[0].error
|
||||
|
||||
def test_one_model_error_does_not_stop_other_models(self) -> None:
|
||||
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
|
||||
|
||||
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
|
||||
emitter = kwargs["emitter"]
|
||||
# _model_idx is always int (0 for N=1, 0/1/2… for N>1)
|
||||
if emitter._model_idx == 0:
|
||||
raise RuntimeError("model 0 failed")
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=fail_model_0_succeed_model_1,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=2))
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 1
|
||||
|
||||
def test_cancellation_yields_user_cancelled_stop(self) -> None:
|
||||
"""If check_is_connected returns False, drain loop emits user_cancelled."""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert any(
|
||||
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
|
||||
for s in stops
|
||||
)
|
||||
|
||||
def test_completion_handle_called_on_disconnect(self) -> None:
|
||||
"""llm_loop_completion_handle must still be called even when user disconnects.
|
||||
|
||||
Regression test for the disconnect-cleanup bug: the old
|
||||
run_chat_loop_with_state_containers always called completion_callback in
|
||||
its finally block (even on disconnect) so the DB message was updated from
|
||||
the TERMINATED placeholder to a partial answer. The new _run_models must
|
||||
replicate this — otherwise the integration test
|
||||
test_send_message_disconnect_and_cleanup fails because the message stays
|
||||
as "Response was terminated prior to completion, try regenerating."
|
||||
"""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Must be called once per model, not zero times
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_called_for_each_successful_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be called once per model that succeeded."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_not_called_for_failed_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be skipped for a model that raised."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("fail")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
mock_handle.assert_not_called()
|
||||
|
||||
def test_http_disconnect_completion_via_generator_exit(self) -> None:
|
||||
"""GeneratorExit from HTTP disconnect triggers worker self-completion.
|
||||
|
||||
When the HTTP client closes the connection, Starlette throws GeneratorExit
|
||||
into the stream generator. The finally block sets drain_done (signalling
|
||||
emitters to stop blocking) and calls executor.shutdown(wait=False) so the
|
||||
server thread is never blocked. Worker threads detect drain_done.is_set()
|
||||
after run_llm_loop completes and self-persist the result via
|
||||
llm_loop_completion_handle using their own DB session.
|
||||
|
||||
This is the primary regression for test_send_message_disconnect_and_cleanup:
|
||||
the integration test disconnects mid-stream and expects the DB message to be
|
||||
updated from the TERMINATED placeholder to the real response.
|
||||
"""
|
||||
import threading
|
||||
|
||||
# Signals the worker to unblock from run_llm_loop after gen.close() returns.
|
||||
# This guarantees drain_done is set BEFORE the worker returns from run_llm_loop,
|
||||
# so the self-completion path (drain_done.is_set() check) is always taken.
|
||||
disconnect_received = threading.Event()
|
||||
# Set by the llm_loop_completion_handle mock when called.
|
||||
completion_called = threading.Event()
|
||||
|
||||
def emit_then_complete(**kwargs: Any) -> None:
|
||||
"""Emit one packet (to give the drain loop a yield point), then block
|
||||
until the main thread signals that gen.close() has been called. This
|
||||
ensures drain_done is set before we return so model_succeeded is checked
|
||||
against a set drain_done — no race condition.
|
||||
"""
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
disconnect_received.wait(timeout=5)
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_then_complete,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
# cast to Generator so .close() is available; _run_models returns
|
||||
# AnswerStream (= Iterator) but the actual object is always a generator.
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
# Advance to the first yielded packet — generator suspends at `yield item`.
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
# Simulate Starlette closing the stream on HTTP client disconnect.
|
||||
# GeneratorExit is thrown at the `yield item` suspension point.
|
||||
gen.close()
|
||||
# Unblock the worker now that drain_done has been set by gen.close().
|
||||
disconnect_received.set()
|
||||
|
||||
# Worker self-completes asynchronously (executor.shutdown(wait=False)).
|
||||
# Wait here, inside the patch context, so that get_session_with_current_tenant
|
||||
# and llm_loop_completion_handle mocks are still active when the worker calls them.
|
||||
assert completion_called.wait(
|
||||
timeout=5
|
||||
), "worker must self-complete via drain_done within 5 seconds"
|
||||
assert (
|
||||
mock_handle.call_count == 1
|
||||
), "completion handle must be called once for the successful model"
|
||||
|
||||
def test_external_state_container_used_for_model_zero(self) -> None:
|
||||
"""When provided, external_state_container is used as state_containers[0]."""
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
external = ChatStateContainer()
|
||||
setup = _make_setup(n_models=1)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
list(
|
||||
_run_models(
|
||||
setup, MagicMock(), MagicMock(), external_state_container=external
|
||||
)
|
||||
)
|
||||
|
||||
# The state_container kwarg passed to run_llm_loop must be the external one
|
||||
call_kwargs = mock_llm.call_args.kwargs
|
||||
assert call_kwargs["state_container"] is external
|
||||
@@ -1,15 +1,23 @@
|
||||
"""Tests for Canvas connector — client (PR1)."""
|
||||
"""Tests for Canvas connector — client, credentials, conversion."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.canvas.client import CanvasApiClient
|
||||
from onyx.connectors.canvas.connector import CanvasConnector
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -18,6 +26,77 @@ FAKE_BASE_URL = "https://myschool.instructure.com"
|
||||
FAKE_TOKEN = "fake-canvas-token"
|
||||
|
||||
|
||||
def _mock_course(
|
||||
course_id: int = 1,
|
||||
name: str = "Intro to CS",
|
||||
course_code: str = "CS101",
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"id": course_id,
|
||||
"name": name,
|
||||
"course_code": course_code,
|
||||
"created_at": "2025-01-01T00:00:00Z",
|
||||
"workflow_state": "available",
|
||||
}
|
||||
|
||||
|
||||
def _build_connector(base_url: str = FAKE_BASE_URL) -> CanvasConnector:
|
||||
"""Build a connector with mocked credential validation."""
|
||||
with patch("onyx.connectors.canvas.client.rl_requests") as mock_req:
|
||||
mock_req.get.return_value = _mock_response(json_data=[_mock_course()])
|
||||
connector = CanvasConnector(canvas_base_url=base_url)
|
||||
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
|
||||
return connector
|
||||
|
||||
|
||||
def _mock_page(
|
||||
page_id: int = 10,
|
||||
title: str = "Syllabus",
|
||||
updated_at: str = "2025-06-01T12:00:00Z",
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"page_id": page_id,
|
||||
"url": "syllabus",
|
||||
"title": title,
|
||||
"body": "<p>Welcome to the course</p>",
|
||||
"created_at": "2025-01-15T00:00:00Z",
|
||||
"updated_at": updated_at,
|
||||
}
|
||||
|
||||
|
||||
def _mock_assignment(
|
||||
assignment_id: int = 20,
|
||||
name: str = "Homework 1",
|
||||
course_id: int = 1,
|
||||
updated_at: str = "2025-06-01T12:00:00Z",
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"id": assignment_id,
|
||||
"name": name,
|
||||
"description": "<p>Solve these problems</p>",
|
||||
"html_url": f"{FAKE_BASE_URL}/courses/{course_id}/assignments/{assignment_id}",
|
||||
"course_id": course_id,
|
||||
"created_at": "2025-01-20T00:00:00Z",
|
||||
"updated_at": updated_at,
|
||||
"due_at": "2025-02-01T23:59:00Z",
|
||||
}
|
||||
|
||||
|
||||
def _mock_announcement(
|
||||
announcement_id: int = 30,
|
||||
title: str = "Class Cancelled",
|
||||
course_id: int = 1,
|
||||
posted_at: str = "2025-06-01T12:00:00Z",
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"id": announcement_id,
|
||||
"title": title,
|
||||
"message": "<p>No class today</p>",
|
||||
"html_url": f"{FAKE_BASE_URL}/courses/{course_id}/discussion_topics/{announcement_id}",
|
||||
"posted_at": posted_at,
|
||||
}
|
||||
|
||||
|
||||
def _mock_response(
|
||||
status_code: int = 200,
|
||||
json_data: Any = None,
|
||||
@@ -325,6 +404,57 @@ class TestGet:
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient.paginate tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPaginate:
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_single_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[{"id": 1}, {"id": 2}]
|
||||
)
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
pages = list(client.paginate("courses"))
|
||||
|
||||
assert len(pages) == 1
|
||||
assert pages[0] == [{"id": 1}, {"id": 2}]
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_two_pages(self, mock_requests: MagicMock) -> None:
|
||||
next_link = f'<{FAKE_BASE_URL}/api/v1/courses?page=2>; rel="next"'
|
||||
page1 = _mock_response(json_data=[{"id": 1}], link_header=next_link)
|
||||
page2 = _mock_response(json_data=[{"id": 2}])
|
||||
mock_requests.get.side_effect = [page1, page2]
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
pages = list(client.paginate("courses"))
|
||||
|
||||
assert len(pages) == 2
|
||||
assert pages[0] == [{"id": 1}]
|
||||
assert pages[1] == [{"id": 2}]
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_empty_response(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
client = CanvasApiClient(
|
||||
bearer_token=FAKE_TOKEN,
|
||||
canvas_base_url=FAKE_BASE_URL,
|
||||
)
|
||||
|
||||
pages = list(client.paginate("courses"))
|
||||
|
||||
assert pages == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasApiClient._parse_next_link tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -379,3 +509,368 @@ class TestParseNextLink:
|
||||
|
||||
with pytest.raises(OnyxError, match="must use https"):
|
||||
self.client._parse_next_link(header)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasConnector — credential loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLoadCredentials:
|
||||
def _assert_load_credentials_raises(
|
||||
self,
|
||||
status_code: int,
|
||||
expected_error: type[Exception],
|
||||
mock_requests: MagicMock,
|
||||
) -> None:
|
||||
"""Helper: assert load_credentials raises expected_error for a given status."""
|
||||
mock_requests.get.return_value = _mock_response(status_code, {})
|
||||
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
|
||||
with pytest.raises(expected_error):
|
||||
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_load_credentials_success(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
|
||||
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
|
||||
|
||||
result = connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
|
||||
|
||||
assert result is None
|
||||
assert connector._canvas_client is not None
|
||||
|
||||
def test_canvas_client_raises_without_credentials(self) -> None:
|
||||
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
|
||||
|
||||
with pytest.raises(ConnectorMissingCredentialError):
|
||||
_ = connector.canvas_client
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_load_credentials_invalid_token(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_load_credentials_raises(401, CredentialExpiredError, mock_requests)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_load_credentials_insufficient_permissions(
|
||||
self, mock_requests: MagicMock
|
||||
) -> None:
|
||||
self._assert_load_credentials_raises(
|
||||
403, InsufficientPermissionsError, mock_requests
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasConnector — URL normalization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConnectorUrlNormalization:
|
||||
def test_strips_api_v1_suffix(self) -> None:
|
||||
connector = _build_connector(base_url=f"{FAKE_BASE_URL}/api/v1")
|
||||
|
||||
result = connector.canvas_base_url
|
||||
expected = FAKE_BASE_URL
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_strips_trailing_slash(self) -> None:
|
||||
connector = _build_connector(base_url=f"{FAKE_BASE_URL}/")
|
||||
|
||||
result = connector.canvas_base_url
|
||||
expected = FAKE_BASE_URL
|
||||
|
||||
assert result == expected
|
||||
|
||||
def test_no_change_for_clean_url(self) -> None:
|
||||
connector = _build_connector(base_url=FAKE_BASE_URL)
|
||||
|
||||
result = connector.canvas_base_url
|
||||
expected = FAKE_BASE_URL
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasConnector — document conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDocumentConversion:
|
||||
def setup_method(self) -> None:
|
||||
self.connector = _build_connector()
|
||||
|
||||
def test_convert_page_to_document(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasPage
|
||||
|
||||
page = CanvasPage(
|
||||
page_id=10,
|
||||
url="syllabus",
|
||||
title="Syllabus",
|
||||
body="<p>Welcome</p>",
|
||||
created_at="2025-01-15T00:00:00Z",
|
||||
updated_at="2025-06-01T12:00:00Z",
|
||||
course_id=1,
|
||||
)
|
||||
|
||||
doc = self.connector._convert_page_to_document(page)
|
||||
|
||||
expected_id = "canvas-page-1-10"
|
||||
expected_metadata = {"course_id": "1", "type": "page"}
|
||||
expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
|
||||
|
||||
assert doc.id == expected_id
|
||||
assert doc.source == DocumentSource.CANVAS
|
||||
assert doc.semantic_identifier == "Syllabus"
|
||||
assert doc.metadata == expected_metadata
|
||||
assert doc.sections[0].link is not None
|
||||
assert f"{FAKE_BASE_URL}/courses/1/pages/syllabus" in doc.sections[0].link
|
||||
assert doc.doc_updated_at == expected_updated_at
|
||||
|
||||
def test_convert_page_without_body(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasPage
|
||||
|
||||
page = CanvasPage(
|
||||
page_id=11,
|
||||
url="empty-page",
|
||||
title="Empty Page",
|
||||
body=None,
|
||||
created_at="2025-01-15T00:00:00Z",
|
||||
updated_at="2025-06-01T12:00:00Z",
|
||||
course_id=1,
|
||||
)
|
||||
|
||||
doc = self.connector._convert_page_to_document(page)
|
||||
section_text = doc.sections[0].text
|
||||
assert section_text is not None
|
||||
|
||||
assert "Empty Page" in section_text
|
||||
assert "<p>" not in section_text
|
||||
|
||||
def test_convert_assignment_to_document(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasAssignment
|
||||
|
||||
assignment = CanvasAssignment(
|
||||
id=20,
|
||||
name="Homework 1",
|
||||
description="<p>Solve these</p>",
|
||||
html_url=f"{FAKE_BASE_URL}/courses/1/assignments/20",
|
||||
course_id=1,
|
||||
created_at="2025-01-20T00:00:00Z",
|
||||
updated_at="2025-06-01T12:00:00Z",
|
||||
due_at="2025-02-01T23:59:00Z",
|
||||
)
|
||||
|
||||
doc = self.connector._convert_assignment_to_document(assignment)
|
||||
|
||||
expected_id = "canvas-assignment-1-20"
|
||||
expected_due_text = "Due: February 01, 2025 23:59 UTC"
|
||||
|
||||
assert doc.id == expected_id
|
||||
assert doc.source == DocumentSource.CANVAS
|
||||
assert doc.semantic_identifier == "Homework 1"
|
||||
assert doc.sections[0].text is not None
|
||||
assert expected_due_text in doc.sections[0].text
|
||||
|
||||
def test_convert_assignment_without_description(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasAssignment
|
||||
|
||||
assignment = CanvasAssignment(
|
||||
id=21,
|
||||
name="Quiz 1",
|
||||
description=None,
|
||||
html_url=f"{FAKE_BASE_URL}/courses/1/assignments/21",
|
||||
course_id=1,
|
||||
created_at="2025-01-20T00:00:00Z",
|
||||
updated_at="2025-06-01T12:00:00Z",
|
||||
due_at=None,
|
||||
)
|
||||
|
||||
doc = self.connector._convert_assignment_to_document(assignment)
|
||||
section_text = doc.sections[0].text
|
||||
assert section_text is not None
|
||||
|
||||
assert "Quiz 1" in section_text
|
||||
assert "Due:" not in section_text
|
||||
|
||||
def test_convert_announcement_to_document(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasAnnouncement
|
||||
|
||||
announcement = CanvasAnnouncement(
|
||||
id=30,
|
||||
title="Class Cancelled",
|
||||
message="<p>No class today</p>",
|
||||
html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/30",
|
||||
posted_at="2025-06-01T12:00:00Z",
|
||||
course_id=1,
|
||||
)
|
||||
|
||||
doc = self.connector._convert_announcement_to_document(announcement)
|
||||
|
||||
expected_id = "canvas-announcement-1-30"
|
||||
expected_updated_at = datetime(2025, 6, 1, 12, 0, tzinfo=timezone.utc)
|
||||
|
||||
assert doc.id == expected_id
|
||||
assert doc.source == DocumentSource.CANVAS
|
||||
assert doc.semantic_identifier == "Class Cancelled"
|
||||
assert doc.doc_updated_at == expected_updated_at
|
||||
|
||||
def test_convert_announcement_without_posted_at(self) -> None:
|
||||
from onyx.connectors.canvas.connector import CanvasAnnouncement
|
||||
|
||||
announcement = CanvasAnnouncement(
|
||||
id=31,
|
||||
title="TBD Announcement",
|
||||
message=None,
|
||||
html_url=f"{FAKE_BASE_URL}/courses/1/discussion_topics/31",
|
||||
posted_at=None,
|
||||
course_id=1,
|
||||
)
|
||||
|
||||
doc = self.connector._convert_announcement_to_document(announcement)
|
||||
|
||||
assert doc.doc_updated_at is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CanvasConnector — validate_connector_settings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateConnectorSettings:
|
||||
def _assert_validate_raises(
|
||||
self,
|
||||
status_code: int,
|
||||
expected_error: type[Exception],
|
||||
mock_requests: MagicMock,
|
||||
) -> None:
|
||||
"""Helper: assert validate_connector_settings raises expected_error."""
|
||||
success_resp = _mock_response(json_data=[_mock_course()])
|
||||
fail_resp = _mock_response(status_code, {})
|
||||
mock_requests.get.side_effect = [success_resp, fail_resp]
|
||||
connector = CanvasConnector(canvas_base_url=FAKE_BASE_URL)
|
||||
connector.load_credentials({"canvas_access_token": FAKE_TOKEN})
|
||||
with pytest.raises(expected_error):
|
||||
connector.validate_connector_settings()
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_success(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[_mock_course()])
|
||||
connector = _build_connector()
|
||||
|
||||
connector.validate_connector_settings() # should not raise
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_expired_credential(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_validate_raises(401, CredentialExpiredError, mock_requests)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_insufficient_permissions(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_validate_raises(403, InsufficientPermissionsError, mock_requests)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_rate_limited(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_validate_raises(429, ConnectorValidationError, mock_requests)
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_validate_unexpected_error(self, mock_requests: MagicMock) -> None:
|
||||
self._assert_validate_raises(500, UnexpectedValidationError, mock_requests)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _list_* pagination tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListCourses:
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_single_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[_mock_course(1), _mock_course(2, "CS201", "Data Structures")]
|
||||
)
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_courses()
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].id == 1
|
||||
assert result[1].id == 2
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_empty_response(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_courses()
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestListPages:
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_single_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[_mock_page(10), _mock_page(11, "Notes")]
|
||||
)
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_pages(course_id=1)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].page_id == 10
|
||||
assert result[1].page_id == 11
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_empty_response(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_pages(course_id=1)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestListAssignments:
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_single_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[_mock_assignment(20), _mock_assignment(21, "Quiz 1")]
|
||||
)
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_assignments(course_id=1)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].id == 20
|
||||
assert result[1].id == 21
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_empty_response(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_assignments(course_id=1)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestListAnnouncements:
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_single_page(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(
|
||||
json_data=[_mock_announcement(30), _mock_announcement(31, "Update")]
|
||||
)
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_announcements(course_id=1)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].id == 30
|
||||
assert result[1].id == 31
|
||||
|
||||
@patch("onyx.connectors.canvas.client.rl_requests")
|
||||
def test_empty_response(self, mock_requests: MagicMock) -> None:
|
||||
mock_requests.get.return_value = _mock_response(json_data=[])
|
||||
connector = _build_connector()
|
||||
|
||||
result = connector._list_announcements(course_id=1)
|
||||
|
||||
assert result == []
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from discord.errors import LoginFailure
|
||||
|
||||
from onyx.connectors.discord.connector import DiscordConnector
|
||||
from onyx.connectors.exceptions import CredentialInvalidError
|
||||
|
||||
|
||||
def _build_connector(token: str = "fake-bot-token") -> DiscordConnector:
|
||||
connector = DiscordConnector()
|
||||
connector.load_credentials({"discord_bot_token": token})
|
||||
return connector
|
||||
|
||||
|
||||
@patch("onyx.connectors.discord.connector.Client.close", new_callable=AsyncMock)
|
||||
@patch("onyx.connectors.discord.connector.Client.login", new_callable=AsyncMock)
|
||||
def test_validate_success(
|
||||
mock_login: AsyncMock,
|
||||
mock_close: AsyncMock,
|
||||
) -> None:
|
||||
connector = _build_connector()
|
||||
connector.validate_connector_settings()
|
||||
|
||||
mock_login.assert_awaited_once_with("fake-bot-token")
|
||||
mock_close.assert_awaited_once()
|
||||
|
||||
|
||||
@patch("onyx.connectors.discord.connector.Client.close", new_callable=AsyncMock)
|
||||
@patch(
|
||||
"onyx.connectors.discord.connector.Client.login",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=LoginFailure("Improper token has been passed."),
|
||||
)
|
||||
def test_validate_invalid_token(
|
||||
mock_login: AsyncMock, # noqa: ARG001
|
||||
mock_close: AsyncMock,
|
||||
) -> None:
|
||||
connector = _build_connector(token="bad-token")
|
||||
|
||||
with pytest.raises(CredentialInvalidError, match="Invalid Discord bot token"):
|
||||
connector.validate_connector_settings()
|
||||
|
||||
mock_close.assert_awaited_once()
|
||||
225
backend/tests/unit/onyx/db/test_chat_sessions.py
Normal file
225
backend/tests/unit/onyx/db/test_chat_sessions.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Tests for get_chat_sessions_by_user filtering behavior.
|
||||
|
||||
Verifies that failed chat sessions (those with only SYSTEM messages) are
|
||||
correctly filtered out while preserving recently created sessions, matching
|
||||
the behavior specified in PR #7233.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
|
||||
def _make_session(
|
||||
user_id: UUID,
|
||||
time_created: datetime | None = None,
|
||||
time_updated: datetime | None = None,
|
||||
description: str = "",
|
||||
) -> MagicMock:
|
||||
"""Create a mock ChatSession with the given attributes."""
|
||||
session = MagicMock(spec=ChatSession)
|
||||
session.id = uuid4()
|
||||
session.user_id = user_id
|
||||
session.time_created = time_created or datetime.now(timezone.utc)
|
||||
session.time_updated = time_updated or session.time_created
|
||||
session.description = description
|
||||
session.deleted = False
|
||||
session.onyxbot_flow = False
|
||||
session.project_id = None
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_id() -> UUID:
|
||||
return uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def old_time() -> datetime:
|
||||
"""A timestamp well outside the 5-minute leeway window."""
|
||||
return datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recent_time() -> datetime:
|
||||
"""A timestamp within the 5-minute leeway window."""
|
||||
return datetime.now(timezone.utc) - timedelta(minutes=2)
|
||||
|
||||
|
||||
class TestGetChatSessionsByUser:
|
||||
"""Tests for the failed chat filtering logic in get_chat_sessions_by_user."""
|
||||
|
||||
def test_filters_out_failed_sessions(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""Sessions with only SYSTEM messages should be excluded."""
|
||||
valid_session = _make_session(user_id, time_created=old_time)
|
||||
failed_session = _make_session(user_id, time_created=old_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
# First execute: returns all sessions
|
||||
# Second execute: returns only the valid session's ID (has non-system msgs)
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [
|
||||
valid_session,
|
||||
failed_session,
|
||||
]
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = [valid_session.id]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == valid_session.id
|
||||
|
||||
def test_keeps_recent_sessions_without_messages(
|
||||
self, user_id: UUID, recent_time: datetime
|
||||
) -> None:
|
||||
"""Recently created sessions should be kept even without messages."""
|
||||
recent_session = _make_session(user_id, time_created=recent_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [recent_session]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == recent_session.id
|
||||
# Should only have been called once — no second query needed
|
||||
# because the recent session is within the leeway window
|
||||
assert db_session.execute.call_count == 1
|
||||
|
||||
def test_include_failed_chats_skips_filtering(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""When include_failed_chats=True, no filtering should occur."""
|
||||
session_a = _make_session(user_id, time_created=old_time)
|
||||
session_b = _make_session(user_id, time_created=old_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [session_a, session_b]
|
||||
|
||||
db_session.execute.side_effect = [mock_result]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=True,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
# Only one DB call — no second query for message validation
|
||||
assert db_session.execute.call_count == 1
|
||||
|
||||
def test_limit_applied_after_filtering(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""Limit should be applied after filtering, not before."""
|
||||
sessions = [_make_session(user_id, time_created=old_time) for _ in range(5)]
|
||||
valid_ids = [s.id for s in sessions[:3]]
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = sessions
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = valid_ids
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
# Should be the first 2 valid sessions (order preserved)
|
||||
assert result[0].id == sessions[0].id
|
||||
assert result[1].id == sessions[1].id
|
||||
|
||||
def test_mixed_recent_and_old_sessions(
|
||||
self, user_id: UUID, old_time: datetime, recent_time: datetime
|
||||
) -> None:
|
||||
"""Mix of recent and old sessions should filter correctly."""
|
||||
old_valid = _make_session(user_id, time_created=old_time)
|
||||
old_failed = _make_session(user_id, time_created=old_time)
|
||||
recent_no_msgs = _make_session(user_id, time_created=recent_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [
|
||||
old_valid,
|
||||
old_failed,
|
||||
recent_no_msgs,
|
||||
]
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = [old_valid.id]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
result_ids = {cs.id for cs in result}
|
||||
assert old_valid.id in result_ids
|
||||
assert recent_no_msgs.id in result_ids
|
||||
assert old_failed.id not in result_ids
|
||||
|
||||
def test_empty_result(self, user_id: UUID) -> None:
|
||||
"""No sessions should return empty list without errors."""
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
db_session.execute.side_effect = [mock_result]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert db_session.execute.call_count == 1
|
||||
@@ -11,30 +11,13 @@ from onyx.hooks.api_dependencies import require_hook_enabled
|
||||
|
||||
class TestRequireHookEnabled:
|
||||
def test_raises_when_multi_tenant(self) -> None:
|
||||
with (
|
||||
patch("onyx.hooks.api_dependencies.MULTI_TENANT", True),
|
||||
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", True),
|
||||
):
|
||||
with patch("onyx.hooks.api_dependencies.MULTI_TENANT", True):
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
require_hook_enabled()
|
||||
assert exc_info.value.error_code is OnyxErrorCode.SINGLE_TENANT_ONLY
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "multi-tenant" in exc_info.value.detail
|
||||
|
||||
def test_raises_when_flag_disabled(self) -> None:
|
||||
with (
|
||||
patch("onyx.hooks.api_dependencies.MULTI_TENANT", False),
|
||||
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", False),
|
||||
):
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
require_hook_enabled()
|
||||
assert exc_info.value.error_code is OnyxErrorCode.ENV_VAR_GATED
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "HOOK_ENABLED" in exc_info.value.detail
|
||||
|
||||
def test_passes_when_enabled_single_tenant(self) -> None:
|
||||
with (
|
||||
patch("onyx.hooks.api_dependencies.MULTI_TENANT", False),
|
||||
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", True),
|
||||
):
|
||||
def test_passes_when_single_tenant(self) -> None:
|
||||
with patch("onyx.hooks.api_dependencies.MULTI_TENANT", False):
|
||||
require_hook_enabled() # must not raise
|
||||
|
||||
0
backend/tests/unit/onyx/server/features/__init__.py
Normal file
0
backend/tests/unit/onyx/server/features/__init__.py
Normal file
@@ -1,6 +1,6 @@
|
||||
"""Tests for memory tool streaming packet emissions."""
|
||||
|
||||
import queue
|
||||
from queue import Queue
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -18,13 +18,9 @@ from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter_queue() -> queue.Queue:
|
||||
return queue.Queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def emitter(emitter_queue: queue.Queue) -> Emitter:
|
||||
return Emitter(merged_queue=emitter_queue)
|
||||
def emitter() -> Emitter:
|
||||
bus: Queue = Queue()
|
||||
return Emitter(bus)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -57,27 +53,24 @@ class TestMemoryToolEmitStart:
|
||||
def test_emit_start_emits_memory_tool_start_packet(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter_queue: queue.Queue,
|
||||
emitter: Emitter,
|
||||
placement: Placement,
|
||||
) -> None:
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
packet = emitter.bus.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolStart)
|
||||
assert packet.placement is not None
|
||||
assert packet.placement.turn_index == placement.turn_index
|
||||
assert packet.placement.tab_index == placement.tab_index
|
||||
assert packet.placement.model_index == 0 # emitter stamps model_index=0
|
||||
assert packet.placement == placement
|
||||
|
||||
def test_emit_start_with_different_placement(
|
||||
self,
|
||||
memory_tool: MemoryTool,
|
||||
emitter_queue: queue.Queue,
|
||||
emitter: Emitter,
|
||||
) -> None:
|
||||
placement = Placement(turn_index=2, tab_index=1)
|
||||
memory_tool.emit_start(placement)
|
||||
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
packet = emitter.bus.get_nowait()
|
||||
assert packet.placement.turn_index == 2
|
||||
assert packet.placement.tab_index == 1
|
||||
|
||||
@@ -88,7 +81,7 @@ class TestMemoryToolRun:
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter_queue: queue.Queue,
|
||||
emitter: Emitter,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -100,19 +93,21 @@ class TestMemoryToolRun:
|
||||
memory="User prefers Python",
|
||||
)
|
||||
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
# The delta packet should be in the queue
|
||||
packet = emitter.bus.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers Python"
|
||||
assert packet.obj.operation == "add"
|
||||
assert packet.obj.memory_id is None
|
||||
assert packet.obj.index is None
|
||||
assert packet.placement == placement
|
||||
|
||||
@patch("onyx.tools.tool_implementations.memory.memory_tool.process_memory_update")
|
||||
def test_run_emits_delta_for_update_operation(
|
||||
self,
|
||||
mock_process: MagicMock,
|
||||
memory_tool: MemoryTool,
|
||||
emitter_queue: queue.Queue,
|
||||
emitter: Emitter,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
) -> None:
|
||||
@@ -124,7 +119,7 @@ class TestMemoryToolRun:
|
||||
memory="User prefers light mode",
|
||||
)
|
||||
|
||||
_key, packet = emitter_queue.get_nowait()
|
||||
packet = emitter.bus.get_nowait()
|
||||
assert isinstance(packet.obj, MemoryToolDelta)
|
||||
assert packet.obj.memory_text == "User prefers light mode"
|
||||
assert packet.obj.operation == "update"
|
||||
|
||||
@@ -203,6 +203,7 @@ prompt_or_default() {
|
||||
local default_value="$2"
|
||||
read_prompt_line "$prompt_text"
|
||||
[[ -z "$REPLY" ]] && REPLY="$default_value"
|
||||
return 0
|
||||
}
|
||||
|
||||
prompt_yn_or_default() {
|
||||
@@ -210,6 +211,7 @@ prompt_yn_or_default() {
|
||||
local default_value="$2"
|
||||
read_prompt_char "$prompt_text"
|
||||
[[ -z "$REPLY" ]] && REPLY="$default_value"
|
||||
return 0
|
||||
}
|
||||
|
||||
confirm_action() {
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.38
|
||||
version: 0.4.39
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
|
||||
2037
deployment/helm/charts/onyx/dashboards/indexing-pipeline.json
Normal file
2037
deployment/helm/charts/onyx/dashboards/indexing-pipeline.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,77 @@
|
||||
{{- if and .Values.monitoring.serviceMonitors.enabled .Values.vectorDB.enabled }}
|
||||
{{- if gt (int .Values.celery_worker_monitoring.replicaCount) 0 }}
|
||||
---
|
||||
apiVersion: monitoring.coreos.com/v1
|
||||
kind: ServiceMonitor
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-monitoring
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.monitoring.serviceMonitors.labels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
namespaceSelector:
|
||||
matchNames:
|
||||
- {{ .Release.Namespace }}
|
||||
selector:
|
||||
matchLabels:
|
||||
app: {{ .Values.celery_worker_monitoring.deploymentLabels.app }}
|
||||
metrics: "true"
|
||||
endpoints:
|
||||
- port: metrics
|
||||
path: /metrics
|
||||
interval: 30s
|
||||
scrapeTimeout: 10s
|
||||
{{- end }}
|
||||
{{- if gt (int .Values.celery_worker_docfetching.replicaCount) 0 }}
|
||||
---
|
||||
apiVersion: monitoring.coreos.com/v1
|
||||
kind: ServiceMonitor
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-docfetching
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.monitoring.serviceMonitors.labels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
namespaceSelector:
|
||||
matchNames:
|
||||
- {{ .Release.Namespace }}
|
||||
selector:
|
||||
matchLabels:
|
||||
app: {{ .Values.celery_worker_docfetching.deploymentLabels.app }}
|
||||
metrics: "true"
|
||||
endpoints:
|
||||
- port: metrics
|
||||
path: /metrics
|
||||
interval: 30s
|
||||
scrapeTimeout: 10s
|
||||
{{- end }}
|
||||
{{- if gt (int .Values.celery_worker_docprocessing.replicaCount) 0 }}
|
||||
---
|
||||
apiVersion: monitoring.coreos.com/v1
|
||||
kind: ServiceMonitor
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-docprocessing
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.monitoring.serviceMonitors.labels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
namespaceSelector:
|
||||
matchNames:
|
||||
- {{ .Release.Namespace }}
|
||||
selector:
|
||||
matchLabels:
|
||||
app: {{ .Values.celery_worker_docprocessing.deploymentLabels.app }}
|
||||
metrics: "true"
|
||||
endpoints:
|
||||
- port: metrics
|
||||
path: /metrics
|
||||
interval: 30s
|
||||
scrapeTimeout: 10s
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -0,0 +1,15 @@
|
||||
{{- if .Values.monitoring.grafana.dashboards.enabled }}
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-indexing-pipeline-dashboard
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
grafana_dashboard: "1"
|
||||
annotations:
|
||||
grafana_folder: "Onyx"
|
||||
data:
|
||||
onyx-indexing-pipeline.json: |
|
||||
{{- .Files.Get "dashboards/indexing-pipeline.json" | nindent 4 }}
|
||||
{{- end }}
|
||||
@@ -256,6 +256,20 @@ tooling:
|
||||
# -- Which client binary to call; change if your image uses a non-default path.
|
||||
psqlBinary: psql
|
||||
|
||||
monitoring:
|
||||
grafana:
|
||||
dashboards:
|
||||
# -- Set to true to deploy Grafana dashboard ConfigMaps for the Onyx indexing pipeline.
|
||||
# Requires kube-prometheus-stack (or equivalent) with the Grafana sidecar enabled and watching this namespace.
|
||||
# The sidecar must be configured with label selector: grafana_dashboard=1
|
||||
enabled: false
|
||||
serviceMonitors:
|
||||
# -- Set to true to deploy ServiceMonitor resources for Celery worker metrics endpoints.
|
||||
# Requires the Prometheus Operator CRDs (included in kube-prometheus-stack).
|
||||
# Use `labels` to match your Prometheus CR's serviceMonitorSelector (e.g. release: onyx-monitoring).
|
||||
enabled: false
|
||||
labels: {}
|
||||
|
||||
serviceAccount:
|
||||
# Specifies whether a service account should be created
|
||||
create: false
|
||||
|
||||
@@ -19,6 +19,10 @@ module "eks" {
|
||||
cluster_endpoint_public_access_cidrs = var.cluster_endpoint_public_access_cidrs
|
||||
enable_cluster_creator_admin_permissions = true
|
||||
|
||||
# Control plane logging
|
||||
cluster_enabled_log_types = var.cluster_enabled_log_types
|
||||
cloudwatch_log_group_retention_in_days = var.cloudwatch_log_group_retention_in_days
|
||||
|
||||
eks_managed_node_group_defaults = {
|
||||
ami_type = "AL2023_x86_64_STANDARD"
|
||||
}
|
||||
|
||||
@@ -161,3 +161,25 @@ variable "rds_db_connect_arn" {
|
||||
description = "Full rds-db:connect ARN to allow (required when enable_rds_iam_for_service_account is true)"
|
||||
default = null
|
||||
}
|
||||
|
||||
variable "cluster_enabled_log_types" {
|
||||
type = list(string)
|
||||
description = "EKS control plane log types to enable (valid: api, audit, authenticator, controllerManager, scheduler)"
|
||||
default = ["api", "audit", "authenticator", "controllerManager", "scheduler"]
|
||||
|
||||
validation {
|
||||
condition = alltrue([for t in var.cluster_enabled_log_types : contains(["api", "audit", "authenticator", "controllerManager", "scheduler"], t)])
|
||||
error_message = "Each entry must be one of: api, audit, authenticator, controllerManager, scheduler."
|
||||
}
|
||||
}
|
||||
|
||||
variable "cloudwatch_log_group_retention_in_days" {
|
||||
type = number
|
||||
description = "Number of days to retain EKS control plane logs in CloudWatch (0 = never expire)"
|
||||
default = 30
|
||||
|
||||
validation {
|
||||
condition = contains([0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653], var.cloudwatch_log_group_retention_in_days)
|
||||
error_message = "Must be a valid CloudWatch retention value (0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653)."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +54,9 @@ module "postgres" {
|
||||
password = var.postgres_password
|
||||
tags = local.merged_tags
|
||||
enable_rds_iam_auth = var.enable_iam_auth
|
||||
|
||||
backup_retention_period = var.postgres_backup_retention_period
|
||||
backup_window = var.postgres_backup_window
|
||||
}
|
||||
|
||||
module "s3" {
|
||||
@@ -80,6 +83,10 @@ module "eks" {
|
||||
public_cluster_enabled = var.public_cluster_enabled
|
||||
private_cluster_enabled = var.private_cluster_enabled
|
||||
cluster_endpoint_public_access_cidrs = var.cluster_endpoint_public_access_cidrs
|
||||
|
||||
# Control plane logging
|
||||
cluster_enabled_log_types = var.eks_cluster_enabled_log_types
|
||||
cloudwatch_log_group_retention_in_days = var.eks_cloudwatch_log_group_retention_in_days
|
||||
}
|
||||
|
||||
module "waf" {
|
||||
|
||||
@@ -250,3 +250,34 @@ variable "opensearch_subnet_ids" {
|
||||
description = "Subnet IDs for OpenSearch. If empty, uses first 3 private subnets."
|
||||
default = []
|
||||
}
|
||||
|
||||
# RDS Backup Configuration
|
||||
variable "postgres_backup_retention_period" {
|
||||
type = number
|
||||
description = "Number of days to retain automated RDS backups (0 to disable)"
|
||||
default = 7
|
||||
}
|
||||
|
||||
variable "postgres_backup_window" {
|
||||
type = string
|
||||
description = "Preferred UTC time window for automated RDS backups (hh24:mi-hh24:mi)"
|
||||
default = "03:00-04:00"
|
||||
}
|
||||
|
||||
# EKS Control Plane Logging
|
||||
variable "eks_cluster_enabled_log_types" {
|
||||
type = list(string)
|
||||
description = "EKS control plane log types to enable (valid: api, audit, authenticator, controllerManager, scheduler)"
|
||||
default = ["api", "audit", "authenticator", "controllerManager", "scheduler"]
|
||||
}
|
||||
|
||||
variable "eks_cloudwatch_log_group_retention_in_days" {
|
||||
type = number
|
||||
description = "Number of days to retain EKS control plane logs in CloudWatch (0 = never expire)"
|
||||
default = 30
|
||||
|
||||
validation {
|
||||
condition = contains([0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653], var.eks_cloudwatch_log_group_retention_in_days)
|
||||
error_message = "Must be a valid CloudWatch retention value (0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653)."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,5 +44,79 @@ resource "aws_db_instance" "this" {
|
||||
publicly_accessible = false
|
||||
deletion_protection = true
|
||||
storage_encrypted = true
|
||||
tags = var.tags
|
||||
|
||||
# Automated backups
|
||||
backup_retention_period = var.backup_retention_period
|
||||
backup_window = var.backup_window
|
||||
|
||||
tags = var.tags
|
||||
}
|
||||
|
||||
# CloudWatch alarm for CPU utilization monitoring
|
||||
resource "aws_cloudwatch_metric_alarm" "cpu_utilization" {
|
||||
alarm_name = "${var.identifier}-cpu-utilization"
|
||||
alarm_description = "RDS CPU utilization for ${var.identifier}"
|
||||
comparison_operator = "GreaterThanThreshold"
|
||||
evaluation_periods = var.cpu_alarm_evaluation_periods
|
||||
metric_name = "CPUUtilization"
|
||||
namespace = "AWS/RDS"
|
||||
period = var.cpu_alarm_period
|
||||
statistic = "Average"
|
||||
threshold = var.cpu_alarm_threshold
|
||||
treat_missing_data = "missing"
|
||||
|
||||
alarm_actions = var.alarm_actions
|
||||
ok_actions = var.alarm_actions
|
||||
|
||||
dimensions = {
|
||||
DBInstanceIdentifier = aws_db_instance.this.identifier
|
||||
}
|
||||
|
||||
tags = var.tags
|
||||
}
|
||||
|
||||
# CloudWatch alarm for disk IO monitoring
|
||||
resource "aws_cloudwatch_metric_alarm" "read_iops" {
|
||||
alarm_name = "${var.identifier}-read-iops"
|
||||
alarm_description = "RDS ReadIOPS for ${var.identifier}"
|
||||
comparison_operator = "GreaterThanThreshold"
|
||||
evaluation_periods = var.iops_alarm_evaluation_periods
|
||||
metric_name = "ReadIOPS"
|
||||
namespace = "AWS/RDS"
|
||||
period = var.iops_alarm_period
|
||||
statistic = "Average"
|
||||
threshold = var.read_iops_alarm_threshold
|
||||
treat_missing_data = "missing"
|
||||
|
||||
alarm_actions = var.alarm_actions
|
||||
ok_actions = var.alarm_actions
|
||||
|
||||
dimensions = {
|
||||
DBInstanceIdentifier = aws_db_instance.this.identifier
|
||||
}
|
||||
|
||||
tags = var.tags
|
||||
}
|
||||
|
||||
# CloudWatch alarm for freeable memory monitoring
|
||||
resource "aws_cloudwatch_metric_alarm" "freeable_memory" {
|
||||
alarm_name = "${var.identifier}-freeable-memory"
|
||||
alarm_description = "RDS freeable memory for ${var.identifier}"
|
||||
comparison_operator = "LessThanThreshold"
|
||||
evaluation_periods = var.memory_alarm_evaluation_periods
|
||||
metric_name = "FreeableMemory"
|
||||
namespace = "AWS/RDS"
|
||||
period = var.memory_alarm_period
|
||||
statistic = "Average"
|
||||
threshold = var.memory_alarm_threshold
|
||||
treat_missing_data = "missing"
|
||||
|
||||
alarm_actions = var.alarm_actions
|
||||
ok_actions = var.alarm_actions
|
||||
|
||||
dimensions = {
|
||||
DBInstanceIdentifier = aws_db_instance.this.identifier
|
||||
}
|
||||
|
||||
tags = var.tags
|
||||
}
|
||||
|
||||
@@ -67,3 +67,131 @@ variable "enable_rds_iam_auth" {
|
||||
description = "Enable AWS IAM database authentication for this RDS instance"
|
||||
default = false
|
||||
}
|
||||
|
||||
variable "backup_retention_period" {
|
||||
type = number
|
||||
description = "Number of days to retain automated backups (0 to disable)"
|
||||
default = 7
|
||||
|
||||
validation {
|
||||
condition = var.backup_retention_period >= 0 && var.backup_retention_period <= 35
|
||||
error_message = "backup_retention_period must be between 0 and 35 (AWS RDS limit)."
|
||||
}
|
||||
}
|
||||
|
||||
variable "backup_window" {
|
||||
type = string
|
||||
description = "Preferred UTC time window for automated backups (hh24:mi-hh24:mi)"
|
||||
default = "03:00-04:00"
|
||||
|
||||
validation {
|
||||
condition = can(regex("^([01]\\d|2[0-3]):[0-5]\\d-([01]\\d|2[0-3]):[0-5]\\d$", var.backup_window))
|
||||
error_message = "backup_window must be in hh24:mi-hh24:mi format (e.g. \"03:00-04:00\")."
|
||||
}
|
||||
}
|
||||
|
||||
# CloudWatch CPU alarm configuration
|
||||
variable "cpu_alarm_threshold" {
|
||||
type = number
|
||||
description = "CPU utilization percentage threshold for the CloudWatch alarm"
|
||||
default = 80
|
||||
|
||||
validation {
|
||||
condition = var.cpu_alarm_threshold >= 0 && var.cpu_alarm_threshold <= 100
|
||||
error_message = "cpu_alarm_threshold must be between 0 and 100 (percentage)."
|
||||
}
|
||||
}
|
||||
|
||||
variable "cpu_alarm_evaluation_periods" {
|
||||
type = number
|
||||
description = "Number of consecutive periods the threshold must be breached before alarming"
|
||||
default = 3
|
||||
|
||||
validation {
|
||||
condition = var.cpu_alarm_evaluation_periods >= 1
|
||||
error_message = "cpu_alarm_evaluation_periods must be at least 1."
|
||||
}
|
||||
}
|
||||
|
||||
variable "cpu_alarm_period" {
|
||||
type = number
|
||||
description = "Period in seconds over which the CPU metric is evaluated"
|
||||
default = 300
|
||||
|
||||
validation {
|
||||
condition = var.cpu_alarm_period >= 60 && var.cpu_alarm_period % 60 == 0
|
||||
error_message = "cpu_alarm_period must be a multiple of 60 seconds and at least 60 (CloudWatch requirement)."
|
||||
}
|
||||
}
|
||||
|
||||
variable "memory_alarm_threshold" {
|
||||
type = number
|
||||
description = "Freeable memory threshold in bytes. Alarm fires when memory drops below this value."
|
||||
default = 256000000 # 256 MB
|
||||
|
||||
validation {
|
||||
condition = var.memory_alarm_threshold > 0
|
||||
error_message = "memory_alarm_threshold must be greater than 0."
|
||||
}
|
||||
}
|
||||
|
||||
variable "memory_alarm_evaluation_periods" {
|
||||
type = number
|
||||
description = "Number of consecutive periods the threshold must be breached before alarming"
|
||||
default = 3
|
||||
|
||||
validation {
|
||||
condition = var.memory_alarm_evaluation_periods >= 1
|
||||
error_message = "memory_alarm_evaluation_periods must be at least 1."
|
||||
}
|
||||
}
|
||||
|
||||
variable "memory_alarm_period" {
|
||||
type = number
|
||||
description = "Period in seconds over which the freeable memory metric is evaluated"
|
||||
default = 300
|
||||
|
||||
validation {
|
||||
condition = var.memory_alarm_period >= 60 && var.memory_alarm_period % 60 == 0
|
||||
error_message = "memory_alarm_period must be a multiple of 60 seconds and at least 60 (CloudWatch requirement)."
|
||||
}
|
||||
}
|
||||
|
||||
variable "read_iops_alarm_threshold" {
|
||||
type = number
|
||||
description = "ReadIOPS threshold. Alarm fires when IOPS exceeds this value."
|
||||
default = 3000
|
||||
|
||||
validation {
|
||||
condition = var.read_iops_alarm_threshold > 0
|
||||
error_message = "read_iops_alarm_threshold must be greater than 0."
|
||||
}
|
||||
}
|
||||
|
||||
variable "iops_alarm_evaluation_periods" {
|
||||
type = number
|
||||
description = "Number of consecutive periods the IOPS threshold must be breached before alarming"
|
||||
default = 3
|
||||
|
||||
validation {
|
||||
condition = var.iops_alarm_evaluation_periods >= 1
|
||||
error_message = "iops_alarm_evaluation_periods must be at least 1."
|
||||
}
|
||||
}
|
||||
|
||||
variable "iops_alarm_period" {
|
||||
type = number
|
||||
description = "Period in seconds over which the IOPS metric is evaluated"
|
||||
default = 300
|
||||
|
||||
validation {
|
||||
condition = var.iops_alarm_period >= 60 && var.iops_alarm_period % 60 == 0
|
||||
error_message = "iops_alarm_period must be a multiple of 60 seconds and at least 60 (CloudWatch requirement)."
|
||||
}
|
||||
}
|
||||
|
||||
variable "alarm_actions" {
|
||||
type = list(string)
|
||||
description = "List of ARNs to notify when the alarm transitions state (e.g. SNS topic ARNs)"
|
||||
default = []
|
||||
}
|
||||
|
||||
349
profiling/grafana/dashboards/onyx/opensearch-search-latency.json
Normal file
349
profiling/grafana/dashboards/onyx/opensearch-search-latency.json
Normal file
@@ -0,0 +1,349 @@
|
||||
{
|
||||
"annotations": {
|
||||
"list": [
|
||||
{
|
||||
"builtIn": 1,
|
||||
"datasource": { "type": "grafana", "uid": "-- Grafana --" },
|
||||
"enable": true,
|
||||
"hide": true,
|
||||
"iconColor": "rgba(0, 211, 255, 1)",
|
||||
"name": "Annotations & Alerts",
|
||||
"type": "dashboard"
|
||||
}
|
||||
]
|
||||
},
|
||||
"editable": true,
|
||||
"fiscalYearStartMonth": 0,
|
||||
"graphTooltip": 1,
|
||||
"id": null,
|
||||
"links": [],
|
||||
"liveNow": true,
|
||||
"panels": [
|
||||
{
|
||||
"title": "Client-Side Search Latency (P50 / P95 / P99)",
|
||||
"description": "End-to-end latency as measured by the Python client, including network round-trip and serialization overhead.",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 0 },
|
||||
"id": 1,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "seconds",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "none" },
|
||||
"thresholdsStyle": { "mode": "dashed" }
|
||||
},
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{ "color": "green", "value": null },
|
||||
{ "color": "yellow", "value": 0.5 },
|
||||
{ "color": "red", "value": 2.0 }
|
||||
]
|
||||
},
|
||||
"unit": "s",
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P50",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.95, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P95",
|
||||
"refId": "B"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.99, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P99",
|
||||
"refId": "C"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Server-Side Search Latency (P50 / P95 / P99)",
|
||||
"description": "OpenSearch server-side execution time from the 'took' field in the response. Does not include network or client-side overhead.",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 0 },
|
||||
"id": 2,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "seconds",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "none" },
|
||||
"thresholdsStyle": { "mode": "dashed" }
|
||||
},
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{ "color": "green", "value": null },
|
||||
{ "color": "yellow", "value": 0.5 },
|
||||
{ "color": "red", "value": 2.0 }
|
||||
]
|
||||
},
|
||||
"unit": "s",
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P50",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.95, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P95",
|
||||
"refId": "B"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.99, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "P99",
|
||||
"refId": "C"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Client-Side Latency by Search Type (P95)",
|
||||
"description": "P95 client-side latency broken down by search type (hybrid, keyword, semantic, random, doc_id_retrieval).",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 10 },
|
||||
"id": 3,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "seconds",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "none" },
|
||||
"thresholdsStyle": { "mode": "off" }
|
||||
},
|
||||
"unit": "s",
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.95, sum by (search_type, le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "{{ search_type }}",
|
||||
"refId": "A"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Search Throughput by Type",
|
||||
"description": "Searches per second broken down by search type.",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 10 },
|
||||
"id": 4,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "searches/s",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "normal" },
|
||||
"thresholdsStyle": { "mode": "off" }
|
||||
},
|
||||
"unit": "ops",
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum by (search_type) (rate(onyx_opensearch_search_total[5m]))",
|
||||
"legendFormat": "{{ search_type }}",
|
||||
"refId": "A"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Concurrent Searches In Progress",
|
||||
"description": "Number of OpenSearch searches currently in flight, broken down by search type. Summed across all instances.",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 20 },
|
||||
"id": 5,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "searches",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "normal" },
|
||||
"thresholdsStyle": { "mode": "off" }
|
||||
},
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum by (search_type) (onyx_opensearch_searches_in_progress)",
|
||||
"legendFormat": "{{ search_type }}",
|
||||
"refId": "A"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Client vs Server Latency Overhead (P50)",
|
||||
"description": "Difference between client-side and server-side P50 latency. Reveals network, serialization, and untracked OpenSearch overhead.",
|
||||
"type": "timeseries",
|
||||
"gridPos": { "h": 10, "w": 12, "x": 12, "y": 20 },
|
||||
"id": 6,
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisLabel": "seconds",
|
||||
"axisPlacement": "auto",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"lineInterpolation": "smooth",
|
||||
"lineWidth": 2,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": { "type": "linear" },
|
||||
"showPoints": "never",
|
||||
"spanNulls": false,
|
||||
"stacking": { "group": "A", "mode": "none" },
|
||||
"thresholdsStyle": { "mode": "off" }
|
||||
},
|
||||
"unit": "s",
|
||||
"min": 0
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m]))) - histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "Client - Server overhead (P50)",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "Client P50",
|
||||
"refId": "B"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "Server P50",
|
||||
"refId": "C"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"refresh": "5s",
|
||||
"schemaVersion": 37,
|
||||
"style": "dark",
|
||||
"tags": ["onyx", "opensearch", "search", "latency"],
|
||||
"templating": {
|
||||
"list": [
|
||||
{
|
||||
"current": {
|
||||
"text": "Prometheus",
|
||||
"value": "prometheus"
|
||||
},
|
||||
"includeAll": false,
|
||||
"name": "DS_PROMETHEUS",
|
||||
"options": [],
|
||||
"query": "prometheus",
|
||||
"refresh": 1,
|
||||
"type": "datasource"
|
||||
}
|
||||
]
|
||||
},
|
||||
"time": { "from": "now-60m", "to": "now" },
|
||||
"timepicker": {
|
||||
"refresh_intervals": ["5s", "10s", "30s", "1m"]
|
||||
},
|
||||
"timezone": "",
|
||||
"title": "Onyx OpenSearch Search Latency",
|
||||
"uid": "onyx-opensearch-search-latency",
|
||||
"version": 0,
|
||||
"weekStart": ""
|
||||
}
|
||||
@@ -73,11 +73,17 @@ ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
|
||||
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
|
||||
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
|
||||
|
||||
ARG SENTRY_RELEASE
|
||||
ENV SENTRY_RELEASE=${SENTRY_RELEASE}
|
||||
|
||||
# Add NODE_OPTIONS argument
|
||||
ARG NODE_OPTIONS
|
||||
|
||||
# SENTRY_AUTH_TOKEN is injected via BuildKit secret mount so it is never written
|
||||
# to any image layer, build cache, or registry manifest.
|
||||
# Use NODE_OPTIONS in the build command
|
||||
RUN NODE_OPTIONS="${NODE_OPTIONS}" npx next build
|
||||
RUN --mount=type=secret,id=sentry_auth_token,env=SENTRY_AUTH_TOKEN \
|
||||
NODE_OPTIONS="${NODE_OPTIONS}" npx next build
|
||||
|
||||
# Step 2. Production image, copy all the files and run next
|
||||
FROM base AS runner
|
||||
@@ -150,6 +156,9 @@ ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
|
||||
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
|
||||
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
|
||||
|
||||
ARG SENTRY_RELEASE
|
||||
ENV SENTRY_RELEASE=${SENTRY_RELEASE}
|
||||
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION}
|
||||
|
||||
@@ -24,6 +24,7 @@ type TextFont =
|
||||
| "secondary-body"
|
||||
| "secondary-action"
|
||||
| "secondary-mono"
|
||||
| "secondary-mono-label"
|
||||
| "figure-small-label"
|
||||
| "figure-small-value"
|
||||
| "figure-keystroke";
|
||||
@@ -88,6 +89,7 @@ const FONT_CONFIG: Record<TextFont, string> = {
|
||||
"secondary-body": "font-secondary-body",
|
||||
"secondary-action": "font-secondary-action",
|
||||
"secondary-mono": "font-secondary-mono",
|
||||
"secondary-mono-label": "font-secondary-mono-label",
|
||||
"figure-small-label": "font-figure-small-label",
|
||||
"figure-small-value": "font-figure-small-value",
|
||||
"figure-keystroke": "font-figure-keystroke",
|
||||
|
||||
@@ -8,6 +8,7 @@ import * as Sentry from "@sentry/nextjs";
|
||||
if (process.env.NEXT_PUBLIC_SENTRY_DSN) {
|
||||
Sentry.init({
|
||||
dsn: process.env.NEXT_PUBLIC_SENTRY_DSN,
|
||||
release: process.env.SENTRY_RELEASE,
|
||||
// Only capture unhandled exceptions
|
||||
tracesSampleRate: 0,
|
||||
debug: false,
|
||||
|
||||
@@ -7,6 +7,7 @@ import * as Sentry from "@sentry/nextjs";
|
||||
if (process.env.NEXT_PUBLIC_SENTRY_DSN) {
|
||||
Sentry.init({
|
||||
dsn: process.env.NEXT_PUBLIC_SENTRY_DSN,
|
||||
release: process.env.SENTRY_RELEASE,
|
||||
|
||||
// Setting this option to true will print useful information to the console while you're setting up Sentry.
|
||||
debug: false,
|
||||
|
||||
@@ -1,7 +1 @@
|
||||
"use client";
|
||||
|
||||
import CodeInterpreterPage from "@/refresh-pages/admin/CodeInterpreterPage";
|
||||
|
||||
export default function Page() {
|
||||
return <CodeInterpreterPage />;
|
||||
}
|
||||
export { default } from "@/refresh-pages/admin/CodeInterpreterPage";
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { ModalCreationInterface } from "@/refresh-components/contexts/ModalContext";
|
||||
import { ImageProvider } from "@/app/admin/configuration/image-generation/constants";
|
||||
import { LLMProviderView } from "@/interfaces/llm";
|
||||
import { ImageGenerationConfigView } from "@/lib/configuration/imageConfigurationService";
|
||||
import { getImageGenForm } from "./forms";
|
||||
|
||||
interface Props {
|
||||
modal: ModalCreationInterface;
|
||||
imageProvider: ImageProvider;
|
||||
existingProviders: LLMProviderView[];
|
||||
existingConfig?: ImageGenerationConfigView;
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Modal for creating/editing image generation configurations.
|
||||
* Routes to provider-specific forms based on imageProvider.provider_name.
|
||||
*/
|
||||
export default function ImageGenerationConnectionModal(props: Props) {
|
||||
return <>{getImageGenForm(props)}</>;
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
export * from "./types";
|
||||
export { ImageGenFormWrapper } from "./ImageGenFormWrapper";
|
||||
export { OpenAIImageGenForm } from "./OpenAIImageGenForm";
|
||||
export { AzureImageGenForm } from "./AzureImageGenForm";
|
||||
export { getImageGenForm } from "./getImageGenForm";
|
||||
@@ -1,22 +1 @@
|
||||
"use client";
|
||||
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import ImageGenerationContent from "./ImageGenerationContent";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTES.IMAGE_GENERATION;
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={route.icon}
|
||||
title={route.title}
|
||||
description="Settings for in-chat image generation."
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<ImageGenerationContent />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
export { default } from "@/refresh-pages/admin/ImageGenerationPage";
|
||||
|
||||
@@ -1,7 +1 @@
|
||||
"use client";
|
||||
|
||||
import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
|
||||
export default function Page() {
|
||||
return <LLMConfigurationPage />;
|
||||
}
|
||||
export { default } from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import { Form, Formik } from "formik";
|
||||
import { mutate } from "swr";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import * as Yup from "yup";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import {
|
||||
@@ -119,6 +121,10 @@ export const DocumentSetCreationForm = ({
|
||||
? "Successfully updated document set!"
|
||||
: "Successfully created document set!"
|
||||
);
|
||||
await Promise.all([
|
||||
mutate(SWR_KEYS.documentSets),
|
||||
mutate(SWR_KEYS.documentSetsEditable),
|
||||
]);
|
||||
onClose();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { DocumentSetSummary } from "@/lib/types";
|
||||
import useSWR, { mutate } from "swr";
|
||||
|
||||
const DOCUMENT_SETS_URL = "/api/manage/document-set";
|
||||
const GET_EDITABLE_DOCUMENT_SETS_URL =
|
||||
"/api/manage/document-set?get_editable=true";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
|
||||
export function refreshDocumentSets() {
|
||||
mutate(DOCUMENT_SETS_URL);
|
||||
mutate(SWR_KEYS.documentSets);
|
||||
}
|
||||
|
||||
export function useDocumentSets(getEditable: boolean = false) {
|
||||
const url = getEditable ? GET_EDITABLE_DOCUMENT_SETS_URL : DOCUMENT_SETS_URL;
|
||||
const url = getEditable
|
||||
? SWR_KEYS.documentSetsEditable
|
||||
: SWR_KEYS.documentSets;
|
||||
|
||||
const swrResponse = useSWR<DocumentSetSummary[]>(url, errorHandlingFetcher, {
|
||||
refreshInterval: 5000, // 5 seconds
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
export { default } from "@/refresh-pages/admin/HooksPage";
|
||||
@@ -182,8 +182,7 @@ export async function* sendMessage({
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const data = await response.json().catch(() => ({}));
|
||||
throw new Error(data.detail ?? `HTTP error! status: ${response.status}`);
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
yield* handleSSEStream<PacketType>(response, signal);
|
||||
|
||||
1
web/src/app/ee/admin/hooks/page.tsx
Normal file
1
web/src/app/ee/admin/hooks/page.tsx
Normal file
@@ -0,0 +1 @@
|
||||
export { default } from "@/ee/refresh-pages/admin/HooksPage";
|
||||
@@ -11,6 +11,7 @@ import Text from "@/refresh-components/texts/Text";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { ChatSessionMinimal } from "@/app/ee/admin/performance/usage/types";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { timestampToReadableDate } from "@/lib/dateUtils";
|
||||
import { Dispatch, SetStateAction, useCallback, useState } from "react";
|
||||
import { Feedback, TaskStatus } from "@/lib/types";
|
||||
@@ -101,34 +102,32 @@ function SelectFeedbackType({
|
||||
onValueChange: (value: Feedback | "all") => void;
|
||||
}) {
|
||||
return (
|
||||
<div>
|
||||
<Text as="p" className="my-auto mr-2 font-medium mb-1">
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" className="font-medium">
|
||||
Feedback Type
|
||||
</Text>
|
||||
<div className="max-w-sm space-y-6">
|
||||
<InputSelect
|
||||
value={value}
|
||||
onValueChange={onValueChange as (value: string) => void}
|
||||
>
|
||||
<InputSelect.Trigger />
|
||||
<InputSelect
|
||||
value={value}
|
||||
onValueChange={onValueChange as (value: string) => void}
|
||||
>
|
||||
<InputSelect.Trigger />
|
||||
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item value="all" icon={SvgMinusCircle}>
|
||||
Any
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="like" icon={SvgThumbsUp}>
|
||||
Like
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="dislike" icon={SvgThumbsDown}>
|
||||
Dislike
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="mixed" icon={SvgMinus}>
|
||||
Mixed
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</div>
|
||||
</div>
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item value="all" icon={SvgMinusCircle}>
|
||||
Any
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="like" icon={SvgThumbsUp}>
|
||||
Like
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="dislike" icon={SvgThumbsDown}>
|
||||
Dislike
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="mixed" icon={SvgMinus}>
|
||||
Mixed
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -185,60 +184,61 @@ function PreviousQueryHistoryExportsModal({
|
||||
onClose={() => setShowModal(false)}
|
||||
/>
|
||||
<Modal.Body>
|
||||
<div className="flex flex-col w-full">
|
||||
<div className="flex flex-1">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Generated At</TableHead>
|
||||
<TableHead>Start Range</TableHead>
|
||||
<TableHead>End Range</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Download</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{paginatedTasks.map((task, index) => (
|
||||
<TableRow key={index}>
|
||||
<TableCell>
|
||||
{humanReadableFormatWithTime(task.startTime)}
|
||||
</TableCell>
|
||||
<TableCell>{task.start.toDateString()}</TableCell>
|
||||
<TableCell>{task.end.toDateString()}</TableCell>
|
||||
<TableCell>
|
||||
<ExportBadge status={task.status} />
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{task.status === "SUCCESS" ? (
|
||||
<a
|
||||
className="flex justify-center"
|
||||
href={withRequestId(
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Generated At</TableHead>
|
||||
<TableHead>Start Range</TableHead>
|
||||
<TableHead>End Range</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Download</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{paginatedTasks.map((task, index) => (
|
||||
<TableRow key={index}>
|
||||
<TableCell>
|
||||
{humanReadableFormatWithTime(task.startTime)}
|
||||
</TableCell>
|
||||
<TableCell>{task.start.toDateString()}</TableCell>
|
||||
<TableCell>{task.end.toDateString()}</TableCell>
|
||||
<TableCell>
|
||||
<ExportBadge status={task.status} />
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Button
|
||||
variant="default"
|
||||
prominence="tertiary"
|
||||
icon={SvgDownloadCloud}
|
||||
size="sm"
|
||||
disabled={task.status !== "SUCCESS"}
|
||||
tooltip={
|
||||
task.status !== "SUCCESS"
|
||||
? "Export is not yet ready"
|
||||
: undefined
|
||||
}
|
||||
href={
|
||||
task.status === "SUCCESS"
|
||||
? withRequestId(
|
||||
DOWNLOAD_QUERY_HISTORY_URL,
|
||||
task.taskId
|
||||
)}
|
||||
>
|
||||
<SvgDownloadCloud className="h-4 w-4 text-action-link-05" />
|
||||
</a>
|
||||
) : (
|
||||
<SvgDownloadCloud className="h-4 w-4 text-action-link-05 opacity-20" />
|
||||
)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<div className="flex mt-3">
|
||||
<div className="mx-auto">
|
||||
<PageSelector
|
||||
currentPage={taskPage}
|
||||
totalPages={totalTaskPages}
|
||||
onPageChange={setTaskPage}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<Section>
|
||||
<PageSelector
|
||||
currentPage={taskPage}
|
||||
totalPages={totalTaskPages}
|
||||
onPageChange={setTaskPage}
|
||||
/>
|
||||
</Section>
|
||||
</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
@@ -330,48 +330,48 @@ export function QueryHistoryTable() {
|
||||
</div>
|
||||
</div>
|
||||
<Separator />
|
||||
<Table className="mt-5">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>First User Message</TableHead>
|
||||
<TableHead>First AI Response</TableHead>
|
||||
<TableHead>Feedback</TableHead>
|
||||
<TableHead>User</TableHead>
|
||||
<TableHead>Persona</TableHead>
|
||||
<TableHead>Date</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
{isLoading ? (
|
||||
<TableBody>
|
||||
<Section>
|
||||
<Table className="mt-5">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableCell colSpan={6} className="text-center">
|
||||
<ThreeDotsLoader />
|
||||
</TableCell>
|
||||
<TableHead>First User Message</TableHead>
|
||||
<TableHead>First AI Response</TableHead>
|
||||
<TableHead>Feedback</TableHead>
|
||||
<TableHead>User</TableHead>
|
||||
<TableHead>Persona</TableHead>
|
||||
<TableHead>Date</TableHead>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
) : (
|
||||
<TableBody>
|
||||
{chatSessionData?.map((chatSessionMinimal) => (
|
||||
<QueryHistoryTableRow
|
||||
key={chatSessionMinimal.id}
|
||||
chatSessionMinimal={chatSessionMinimal}
|
||||
/>
|
||||
))}
|
||||
</TableBody>
|
||||
)}
|
||||
</Table>
|
||||
</TableHeader>
|
||||
{isLoading ? (
|
||||
<TableBody>
|
||||
<TableRow>
|
||||
<TableCell colSpan={6} className="text-center">
|
||||
<ThreeDotsLoader />
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
) : (
|
||||
<TableBody>
|
||||
{chatSessionData?.map((chatSessionMinimal) => (
|
||||
<QueryHistoryTableRow
|
||||
key={chatSessionMinimal.id}
|
||||
chatSessionMinimal={chatSessionMinimal}
|
||||
/>
|
||||
))}
|
||||
</TableBody>
|
||||
)}
|
||||
</Table>
|
||||
|
||||
{chatSessionData && (
|
||||
<div className="mt-3 flex">
|
||||
<div className="mx-auto">
|
||||
{chatSessionData && (
|
||||
<Section>
|
||||
<PageSelector
|
||||
totalPages={totalPages}
|
||||
currentPage={currentPage}
|
||||
onPageChange={goToPage}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</Section>
|
||||
)}
|
||||
</Section>
|
||||
</CardSection>
|
||||
|
||||
{showModal && (
|
||||
|
||||
@@ -330,6 +330,14 @@
|
||||
letter-spacing: 0px;
|
||||
}
|
||||
|
||||
.font-secondary-mono-label {
|
||||
font-family: var(--font-dm-mono);
|
||||
font-size: 12px;
|
||||
font-weight: 500;
|
||||
line-height: 16px;
|
||||
letter-spacing: 0px;
|
||||
}
|
||||
|
||||
/* FIGURE STYLES */
|
||||
|
||||
.font-figure-small-label {
|
||||
|
||||
@@ -206,6 +206,15 @@ export function CCPairStatus({
|
||||
Indexing
|
||||
</Badge>
|
||||
);
|
||||
} else if (
|
||||
lastIndexAttemptStatus &&
|
||||
lastIndexAttemptStatus === "not_started"
|
||||
) {
|
||||
badge = (
|
||||
<Badge variant="not_started" icon={FiClock}>
|
||||
Scheduled
|
||||
</Badge>
|
||||
);
|
||||
} else if (
|
||||
lastIndexAttemptStatus &&
|
||||
lastIndexAttemptStatus === "canceled"
|
||||
|
||||
42
web/src/ee/hooks/useHookExecutionLogs.ts
Normal file
42
web/src/ee/hooks/useHookExecutionLogs.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
import useSWR from "swr";
|
||||
import { fetchExecutionLogs } from "@/ee/refresh-pages/admin/HooksPage/svc";
|
||||
import type { HookExecutionRecord } from "@/ee/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
const ONE_HOUR_MS = 60 * 60 * 1000;
|
||||
const THIRTY_DAYS_MS = 30 * 24 * 60 * 60 * 1000;
|
||||
|
||||
interface UseHookExecutionLogsResult {
|
||||
isLoading: boolean;
|
||||
error: Error | undefined;
|
||||
hasRecentErrors: boolean;
|
||||
recentErrors: HookExecutionRecord[];
|
||||
olderErrors: HookExecutionRecord[];
|
||||
}
|
||||
|
||||
export function useHookExecutionLogs(
|
||||
hookId: number,
|
||||
limit = 10
|
||||
): UseHookExecutionLogsResult {
|
||||
const { data, isLoading, error } = useSWR(
|
||||
["hook-execution-logs", hookId, limit],
|
||||
() => fetchExecutionLogs(hookId, limit),
|
||||
{ refreshInterval: 60_000 }
|
||||
);
|
||||
|
||||
const now = Date.now();
|
||||
|
||||
const recentErrors =
|
||||
data?.filter(
|
||||
(log) => now - new Date(log.created_at).getTime() < ONE_HOUR_MS
|
||||
) ?? [];
|
||||
|
||||
const olderErrors =
|
||||
data?.filter((log) => {
|
||||
const age = now - new Date(log.created_at).getTime();
|
||||
return age >= ONE_HOUR_MS && age < THIRTY_DAYS_MS;
|
||||
}) ?? [];
|
||||
|
||||
const hasRecentErrors = recentErrors.length > 0;
|
||||
|
||||
return { isLoading, error, hasRecentErrors, recentErrors, olderErrors };
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { HookPointMeta } from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
import { HookPointMeta } from "@/ee/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
export function useHookSpecs() {
|
||||
const { data, isLoading, error } = useSWR<HookPointMeta[]>(
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { HookResponse } from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
import { HookResponse } from "@/ee/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
export function useHooks() {
|
||||
const { data, isLoading, error, mutate } = useSWR<HookResponse[]>(
|
||||
@@ -11,7 +11,6 @@ import Card from "@/refresh-components/cards/Card";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import {
|
||||
SvgCheckCircle,
|
||||
SvgExternalLink,
|
||||
SvgPlug,
|
||||
SvgRefreshCw,
|
||||
@@ -23,14 +22,15 @@ import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import type {
|
||||
HookPointMeta,
|
||||
HookResponse,
|
||||
} from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
} from "@/ee/refresh-pages/admin/HooksPage/interfaces";
|
||||
import {
|
||||
activateHook,
|
||||
deactivateHook,
|
||||
deleteHook,
|
||||
validateHook,
|
||||
} from "@/refresh-pages/admin/HooksPage/svc";
|
||||
import { getHookPointIcon } from "@/refresh-pages/admin/HooksPage/hookPointIcons";
|
||||
} from "@/ee/refresh-pages/admin/HooksPage/svc";
|
||||
import { getHookPointIcon } from "@/ee/refresh-pages/admin/HooksPage/hookPointIcons";
|
||||
import HookStatusPopover from "@/ee/refresh-pages/admin/HooksPage/HookStatusPopover";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sub-component: disconnect confirmation modal
|
||||
@@ -328,7 +328,7 @@ export default function ConnectedHookCard({
|
||||
href={spec.docs_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="pl-6 flex items-center gap-1"
|
||||
className="pl-6 flex items-center gap-1 w-fit"
|
||||
>
|
||||
<span className="underline font-secondary-body text-text-03">
|
||||
Documentation
|
||||
@@ -345,21 +345,13 @@ export default function ConnectedHookCard({
|
||||
height="fit"
|
||||
gap={0}
|
||||
>
|
||||
<div className="flex items-center gap-1 p-2">
|
||||
<div className="flex items-center gap-1">
|
||||
{hook.is_active ? (
|
||||
<>
|
||||
<Text mainUiAction text03>
|
||||
Connected
|
||||
</Text>
|
||||
<SvgCheckCircle
|
||||
size={16}
|
||||
className="text-status-success-05"
|
||||
/>
|
||||
</>
|
||||
<HookStatusPopover hook={hook} spec={spec} isBusy={isBusy} />
|
||||
) : (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center gap-1",
|
||||
"flex items-center gap-1 p-2",
|
||||
isBusy ? "opacity-50 pointer-events-none" : "cursor-pointer"
|
||||
)}
|
||||
onClick={handleActivate}
|
||||
@@ -23,14 +23,14 @@ import {
|
||||
HookAuthError,
|
||||
HookTimeoutError,
|
||||
HookConnectError,
|
||||
} from "@/refresh-pages/admin/HooksPage/svc";
|
||||
} from "@/ee/refresh-pages/admin/HooksPage/svc";
|
||||
import type {
|
||||
HookFailStrategy,
|
||||
HookFormState,
|
||||
HookPointMeta,
|
||||
HookResponse,
|
||||
HookUpdateRequest,
|
||||
} from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
} from "@/ee/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user