Compare commits

..

1 Commits

Author SHA1 Message Date
Bo-Onyx
d853621e2d feat(hook): make hook ee feature 2026-03-30 17:13:26 -07:00
79 changed files with 2781 additions and 4922 deletions

View File

@@ -704,9 +704,6 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
@@ -789,9 +786,6 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
SENTRY_RELEASE=${{ github.sha }}
secrets: |
sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest

View File

@@ -5,6 +5,7 @@ from onyx.background.celery.apps.primary import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.hooks",
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cloud",

View File

@@ -55,6 +55,15 @@ ee_tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
ee_tasks_to_schedule = [
{
"name": "hook-execution-log-cleanup",
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
"schedule": timedelta(days=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,

View File

@@ -69,5 +69,7 @@ EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
"/admin/token-rate-limits",
# Evals
"/evals",
# Hook extensions
"/admin/hooks",
}
)

View File

View File

@@ -0,0 +1,385 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
response_type=QueryProcessingResponse,
)
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (response_type)
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
"""
import json
import time
from typing import Any
from typing import TypeVar
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
T = TypeVar("T", bound=BaseModel)
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if MULTI_TENANT:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
*,
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against response_type.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: T | None = None
if outcome.is_success and outcome.response_payload is not None:
try:
validated_model = response_type.model_validate(outcome.response_payload)
except ValidationError as e:
msg = (
f"Hook response failed validation against {response_type.__name__}: {e}"
)
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
)
return validated_model
def _execute_hook_impl(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
"""EE implementation — loaded by CE's execute_hook via fetch_versioned_implementation.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload, response_type)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise

View File

@@ -15,6 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
basic_router as enterprise_settings_router,
)
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.features.hooks.api import router as hook_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.license_enforcement import (
@@ -138,6 +139,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, ee_oauth_router)
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
include_router_with_global_prefix_prepended(application, evals_router)
include_router_with_global_prefix_prepended(application, hook_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(

View File

@@ -317,7 +317,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.hooks",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",

View File

@@ -14,7 +14,6 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.hooks.utils import HOOKS_AVAILABLE
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
@@ -362,19 +361,6 @@ if not MULTI_TENANT:
tasks_to_schedule.extend(beat_task_templates)
if HOOKS_AVAILABLE:
tasks_to_schedule.append(
{
"name": "hook-execution-log-cleanup",
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
"schedule": timedelta(days=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def generate_cloud_tasks(
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float

View File

@@ -1079,7 +1079,6 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
HOOK_ENABLED = os.environ.get("HOOK_ENABLED", "").lower() == "true"
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"

View File

@@ -1,4 +1,3 @@
from onyx.configs.app_configs import HOOK_ENABLED
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from shared_configs.configs import MULTI_TENANT
@@ -7,10 +6,7 @@ from shared_configs.configs import MULTI_TENANT
def require_hook_enabled() -> None:
"""FastAPI dependency that gates all hook management endpoints.
Hooks are only available in single-tenant / self-hosted deployments with
HOOK_ENABLED=true explicitly set. Two layers of protection:
1. MULTI_TENANT check — rejects even if HOOK_ENABLED is accidentally set true
2. HOOK_ENABLED flag — explicit opt-in by the operator
Hooks are only available in single-tenant / self-hosted EE deployments.
Use as: Depends(require_hook_enabled)
"""
@@ -19,8 +15,3 @@ def require_hook_enabled() -> None:
OnyxErrorCode.SINGLE_TENANT_ONLY,
"Hooks are not available in multi-tenant deployments",
)
if not HOOK_ENABLED:
raise OnyxError(
OnyxErrorCode.ENV_VAR_GATED,
"Hooks are not enabled. Set HOOK_ENABLED=true to enable.",
)

View File

@@ -1,79 +1,22 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
"""CE hook executor.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
response_type=QueryProcessingResponse,
)
HookSkipped and HookSoftFailed are real classes kept here because
process_message.py (CE code) uses isinstance checks against them.
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (response_type)
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
execute_hook is the public entry point. It dispatches to _execute_hook_impl
via fetch_versioned_implementation so that:
- CE: onyx.hooks.executor._execute_hook_impl → no-op, returns HookSkipped()
- EE: ee.onyx.hooks.executor._execute_hook_impl → real HTTP call
"""
import json
import time
from typing import Any
from typing import TypeVar
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
logger = setup_logger()
from onyx.utils.variable_functionality import fetch_versioned_implementation
class HookSkipped:
@@ -87,277 +30,15 @@ class HookSoftFailed:
T = TypeVar("T", bound=BaseModel)
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if not HOOKS_AVAILABLE:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
def _execute_hook_impl(
*,
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against response_type.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: T | None = None
if outcome.is_success and outcome.response_payload is not None:
try:
validated_model = response_type.model_validate(outcome.response_payload)
except ValidationError as e:
msg = (
f"Hook response failed validation against {response_type.__name__}: {e}"
)
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
)
return validated_model
db_session: Session, # noqa: ARG001
hook_point: HookPoint, # noqa: ARG001
payload: dict[str, Any], # noqa: ARG001
response_type: type[T], # noqa: ARG001
) -> T | HookSkipped | HookSoftFailed:
"""CE no-op — hooks are not available without EE."""
return HookSkipped()
def execute_hook(
@@ -367,25 +48,15 @@ def execute_hook(
payload: dict[str, Any],
response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously.
"""Execute the hook for the given hook point.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
Dispatches to the versioned implementation so EE gets the real executor
and CE gets the no-op stub, without any changes at the call site.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload, response_type)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise
impl = fetch_versioned_implementation("onyx.hooks.executor", "_execute_hook_impl")
return impl(
db_session=db_session,
hook_point=hook_point,
payload=payload,
response_type=response_type,
)

View File

@@ -1,5 +0,0 @@
from onyx.configs.app_configs import HOOK_ENABLED
from shared_configs.configs import MULTI_TENANT
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT

View File

@@ -77,7 +77,6 @@ from onyx.server.features.default_assistant.api import (
)
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.hierarchy.api import router as hierarchy_router
from onyx.server.features.hooks.api import router as hook_router
from onyx.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
@@ -455,7 +454,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
register_onyx_exception_handlers(application)
include_router_with_global_prefix_prepended(application, hook_router)
include_router_with_global_prefix_prepended(application, password_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, query_router)

View File

@@ -21,7 +21,6 @@ from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
@@ -38,6 +37,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -98,7 +98,7 @@ def fetch_settings(
needs_reindexing=needs_reindexing,
onyx_craft_enabled=onyx_craft_enabled_for_user,
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=HOOKS_AVAILABLE,
hooks_enabled=not MULTI_TENANT,
version=onyx_version,
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
default_user_file_max_upload_size_mb=min(

View File

@@ -116,7 +116,7 @@ class UserSettings(Settings):
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
# document sets are unavailable.
vector_db_enabled: bool = True
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
# True when hooks are available: single-tenant EE deployments only.
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None

View File

@@ -9,11 +9,11 @@ import httpx
import pytest
from pydantic import BaseModel
from ee.onyx.hooks.executor import _execute_hook_impl as execute_hook
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
@@ -118,28 +118,30 @@ def db_session() -> MagicMock:
@pytest.mark.parametrize(
"hooks_available,hook",
"multi_tenant,hook",
[
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
pytest.param(False, None, id="hooks_not_available"),
pytest.param(True, None, id="hook_not_found"),
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
# MULTI_TENANT=True exits before the DB lookup — hook is irrelevant.
pytest.param(True, None, id="multi_tenant"),
pytest.param(False, None, id="hook_not_found"),
pytest.param(False, _make_hook(is_active=False), id="hook_inactive"),
pytest.param(False, _make_hook(endpoint_url=None), id="no_endpoint_url"),
],
)
def test_early_exit_returns_skipped_with_no_db_writes(
db_session: MagicMock,
hooks_available: bool,
multi_tenant: bool,
hook: MagicMock | None,
) -> None:
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
patch("ee.onyx.hooks.executor.MULTI_TENANT", multi_tenant),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
):
result = execute_hook(
db_session=db_session,
@@ -164,14 +166,16 @@ def test_success_returns_validated_model_and_sets_reachable(
hook = _make_hook()
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
@@ -195,14 +199,14 @@ def test_success_skips_reachable_write_when_already_true(db_session: MagicMock)
hook = _make_hook(is_reachable=True)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
@@ -224,14 +228,16 @@ def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
@@ -258,14 +264,16 @@ def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
@@ -384,14 +392,14 @@ def test_http_failure_paths(
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=exception)
@@ -443,14 +451,14 @@ def test_authorization_header(
hook = _make_hook(api_key=api_key)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit"),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit"),
patch("ee.onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
mock_client = _setup_client(mock_client_cls, response=_make_response())
@@ -489,13 +497,13 @@ def test_persist_session_failure_is_swallowed(
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor.get_session_with_current_tenant",
"ee.onyx.hooks.executor.get_session_with_current_tenant",
side_effect=RuntimeError("DB unavailable"),
),
patch("httpx.Client") as mock_client_cls,
@@ -556,14 +564,16 @@ def test_response_validation_failure_respects_fail_strategy(
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
patch("httpx.Client") as mock_client_cls,
):
# Response payload is missing required_field → ValidationError
@@ -619,13 +629,13 @@ def test_unexpected_exception_in_inner_respects_fail_strategy(
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor._execute_hook_inner",
"ee.onyx.hooks.executor._execute_hook_inner",
side_effect=ValueError("unexpected bug"),
),
):
@@ -658,17 +668,19 @@ def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> Non
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch("ee.onyx.hooks.executor.MULTI_TENANT", False),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
"ee.onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("ee.onyx.hooks.executor.get_session_with_current_tenant"),
patch(
"onyx.hooks.executor.update_hook__no_commit",
"ee.onyx.hooks.executor.update_hook__no_commit",
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch(
"ee.onyx.hooks.executor.create_hook_execution_log__no_commit"
) as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))

View File

@@ -1,4 +1,4 @@
"""Unit tests for onyx.server.features.hooks.api helpers.
"""Unit tests for ee.onyx.server.features.hooks.api helpers.
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
@@ -16,13 +16,13 @@ from unittest.mock import patch
import httpx
import pytest
from ee.onyx.server.features.hooks.api import _check_ssrf_safety
from ee.onyx.server.features.hooks.api import _raise_for_validation_failure
from ee.onyx.server.features.hooks.api import _validate_endpoint
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.server.features.hooks.api import _check_ssrf_safety
from onyx.server.features.hooks.api import _raise_for_validation_failure
from onyx.server.features.hooks.api import _validate_endpoint
# ---------------------------------------------------------------------------
# Helpers
@@ -117,28 +117,28 @@ class TestCheckSsrfSafety:
class TestValidateEndpoint:
def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
# Bypass SSRF check — tested separately in TestCheckSsrfSafety.
with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
with patch("ee.onyx.server.features.hooks.api._check_ssrf_safety"):
return _validate_endpoint(
endpoint_url=_URL,
api_key=api_key,
timeout_seconds=_TIMEOUT,
)
@patch("onyx.server.features.hooks.api.httpx.Client")
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(200)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(500)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize("status_code", [401, 403])
def test_401_403_returns_auth_failed(
self, mock_client_cls: MagicMock, status_code: int
@@ -150,21 +150,21 @@ class TestValidateEndpoint:
assert result.status == HookValidateStatus.auth_failed
assert str(status_code) in (result.error_message or "")
@patch("onyx.server.features.hooks.api.httpx.Client")
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(422)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.timeout
@patch("onyx.server.features.hooks.api.httpx.Client")
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(
"exc",
[
@@ -179,7 +179,7 @@ class TestValidateEndpoint:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
assert self._call().status == HookValidateStatus.timeout
@patch("onyx.server.features.hooks.api.httpx.Client")
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
def test_connect_error_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
@@ -189,7 +189,7 @@ class TestValidateEndpoint:
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
def test_arbitrary_exception_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
@@ -198,7 +198,7 @@ class TestValidateEndpoint:
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
@@ -206,7 +206,7 @@ class TestValidateEndpoint:
_, kwargs = mock_post.call_args
assert kwargs["headers"]["Authorization"] == "Bearer mykey"
@patch("onyx.server.features.hooks.api.httpx.Client")
@patch("ee.onyx.server.features.hooks.api.httpx.Client")
def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)

View File

@@ -11,30 +11,13 @@ from onyx.hooks.api_dependencies import require_hook_enabled
class TestRequireHookEnabled:
def test_raises_when_multi_tenant(self) -> None:
with (
patch("onyx.hooks.api_dependencies.MULTI_TENANT", True),
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", True),
):
with patch("onyx.hooks.api_dependencies.MULTI_TENANT", True):
with pytest.raises(OnyxError) as exc_info:
require_hook_enabled()
assert exc_info.value.error_code is OnyxErrorCode.SINGLE_TENANT_ONLY
assert exc_info.value.status_code == 403
assert "multi-tenant" in exc_info.value.detail
def test_raises_when_flag_disabled(self) -> None:
with (
patch("onyx.hooks.api_dependencies.MULTI_TENANT", False),
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", False),
):
with pytest.raises(OnyxError) as exc_info:
require_hook_enabled()
assert exc_info.value.error_code is OnyxErrorCode.ENV_VAR_GATED
assert exc_info.value.status_code == 403
assert "HOOK_ENABLED" in exc_info.value.detail
def test_passes_when_enabled_single_tenant(self) -> None:
with (
patch("onyx.hooks.api_dependencies.MULTI_TENANT", False),
patch("onyx.hooks.api_dependencies.HOOK_ENABLED", True),
):
def test_passes_when_single_tenant(self) -> None:
with patch("onyx.hooks.api_dependencies.MULTI_TENANT", False):
require_hook_enabled() # must not raise

View File

@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.4.39
version: 0.4.38
appVersion: latest
annotations:
category: Productivity

File diff suppressed because it is too large Load Diff

View File

@@ -1,77 +0,0 @@
{{- if and .Values.monitoring.serviceMonitors.enabled .Values.vectorDB.enabled }}
{{- if gt (int .Values.celery_worker_monitoring.replicaCount) 0 }}
---
apiVersion: monitoring.coreos.com/v1
kind: ServiceMonitor
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-monitoring
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- with .Values.monitoring.serviceMonitors.labels }}
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
namespaceSelector:
matchNames:
- {{ .Release.Namespace }}
selector:
matchLabels:
app: {{ .Values.celery_worker_monitoring.deploymentLabels.app }}
metrics: "true"
endpoints:
- port: metrics
path: /metrics
interval: 30s
scrapeTimeout: 10s
{{- end }}
{{- if gt (int .Values.celery_worker_docfetching.replicaCount) 0 }}
---
apiVersion: monitoring.coreos.com/v1
kind: ServiceMonitor
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docfetching
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- with .Values.monitoring.serviceMonitors.labels }}
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
namespaceSelector:
matchNames:
- {{ .Release.Namespace }}
selector:
matchLabels:
app: {{ .Values.celery_worker_docfetching.deploymentLabels.app }}
metrics: "true"
endpoints:
- port: metrics
path: /metrics
interval: 30s
scrapeTimeout: 10s
{{- end }}
{{- if gt (int .Values.celery_worker_docprocessing.replicaCount) 0 }}
---
apiVersion: monitoring.coreos.com/v1
kind: ServiceMonitor
metadata:
name: {{ include "onyx.fullname" . }}-celery-worker-docprocessing
labels:
{{- include "onyx.labels" . | nindent 4 }}
{{- with .Values.monitoring.serviceMonitors.labels }}
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
namespaceSelector:
matchNames:
- {{ .Release.Namespace }}
selector:
matchLabels:
app: {{ .Values.celery_worker_docprocessing.deploymentLabels.app }}
metrics: "true"
endpoints:
- port: metrics
path: /metrics
interval: 30s
scrapeTimeout: 10s
{{- end }}
{{- end }}

View File

@@ -1,15 +0,0 @@
{{- if .Values.monitoring.grafana.dashboards.enabled }}
---
apiVersion: v1
kind: ConfigMap
metadata:
name: {{ include "onyx.fullname" . }}-indexing-pipeline-dashboard
labels:
{{- include "onyx.labels" . | nindent 4 }}
grafana_dashboard: "1"
annotations:
grafana_folder: "Onyx"
data:
onyx-indexing-pipeline.json: |
{{- .Files.Get "dashboards/indexing-pipeline.json" | nindent 4 }}
{{- end }}

View File

@@ -256,20 +256,6 @@ tooling:
# -- Which client binary to call; change if your image uses a non-default path.
psqlBinary: psql
monitoring:
grafana:
dashboards:
# -- Set to true to deploy Grafana dashboard ConfigMaps for the Onyx indexing pipeline.
# Requires kube-prometheus-stack (or equivalent) with the Grafana sidecar enabled and watching this namespace.
# The sidecar must be configured with label selector: grafana_dashboard=1
enabled: false
serviceMonitors:
# -- Set to true to deploy ServiceMonitor resources for Celery worker metrics endpoints.
# Requires the Prometheus Operator CRDs (included in kube-prometheus-stack).
# Use `labels` to match your Prometheus CR's serviceMonitorSelector (e.g. release: onyx-monitoring).
enabled: false
labels: {}
serviceAccount:
# Specifies whether a service account should be created
create: false

View File

@@ -19,10 +19,6 @@ module "eks" {
cluster_endpoint_public_access_cidrs = var.cluster_endpoint_public_access_cidrs
enable_cluster_creator_admin_permissions = true
# Control plane logging
cluster_enabled_log_types = var.cluster_enabled_log_types
cloudwatch_log_group_retention_in_days = var.cloudwatch_log_group_retention_in_days
eks_managed_node_group_defaults = {
ami_type = "AL2023_x86_64_STANDARD"
}

View File

@@ -161,25 +161,3 @@ variable "rds_db_connect_arn" {
description = "Full rds-db:connect ARN to allow (required when enable_rds_iam_for_service_account is true)"
default = null
}
variable "cluster_enabled_log_types" {
type = list(string)
description = "EKS control plane log types to enable (valid: api, audit, authenticator, controllerManager, scheduler)"
default = ["api", "audit", "authenticator", "controllerManager", "scheduler"]
validation {
condition = alltrue([for t in var.cluster_enabled_log_types : contains(["api", "audit", "authenticator", "controllerManager", "scheduler"], t)])
error_message = "Each entry must be one of: api, audit, authenticator, controllerManager, scheduler."
}
}
variable "cloudwatch_log_group_retention_in_days" {
type = number
description = "Number of days to retain EKS control plane logs in CloudWatch (0 = never expire)"
default = 30
validation {
condition = contains([0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653], var.cloudwatch_log_group_retention_in_days)
error_message = "Must be a valid CloudWatch retention value (0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653)."
}
}

View File

@@ -83,10 +83,6 @@ module "eks" {
public_cluster_enabled = var.public_cluster_enabled
private_cluster_enabled = var.private_cluster_enabled
cluster_endpoint_public_access_cidrs = var.cluster_endpoint_public_access_cidrs
# Control plane logging
cluster_enabled_log_types = var.eks_cluster_enabled_log_types
cloudwatch_log_group_retention_in_days = var.eks_cloudwatch_log_group_retention_in_days
}
module "waf" {

View File

@@ -263,21 +263,3 @@ variable "postgres_backup_window" {
description = "Preferred UTC time window for automated RDS backups (hh24:mi-hh24:mi)"
default = "03:00-04:00"
}
# EKS Control Plane Logging
variable "eks_cluster_enabled_log_types" {
type = list(string)
description = "EKS control plane log types to enable (valid: api, audit, authenticator, controllerManager, scheduler)"
default = ["api", "audit", "authenticator", "controllerManager", "scheduler"]
}
variable "eks_cloudwatch_log_group_retention_in_days" {
type = number
description = "Number of days to retain EKS control plane logs in CloudWatch (0 = never expire)"
default = 30
validation {
condition = contains([0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653], var.eks_cloudwatch_log_group_retention_in_days)
error_message = "Must be a valid CloudWatch retention value (0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1096, 1827, 2192, 2557, 2922, 3288, 3653)."
}
}

View File

@@ -51,26 +51,3 @@ resource "aws_db_instance" "this" {
tags = var.tags
}
# CloudWatch alarm for CPU utilization monitoring
resource "aws_cloudwatch_metric_alarm" "cpu_utilization" {
alarm_name = "${var.identifier}-cpu-utilization"
alarm_description = "RDS CPU utilization for ${var.identifier}"
comparison_operator = "GreaterThanThreshold"
evaluation_periods = var.cpu_alarm_evaluation_periods
metric_name = "CPUUtilization"
namespace = "AWS/RDS"
period = var.cpu_alarm_period
statistic = "Average"
threshold = var.cpu_alarm_threshold
treat_missing_data = "missing"
alarm_actions = var.alarm_actions
ok_actions = var.alarm_actions
dimensions = {
DBInstanceIdentifier = aws_db_instance.this.identifier
}
tags = var.tags
}

View File

@@ -89,43 +89,3 @@ variable "backup_window" {
error_message = "backup_window must be in hh24:mi-hh24:mi format (e.g. \"03:00-04:00\")."
}
}
# CloudWatch CPU alarm configuration
variable "cpu_alarm_threshold" {
type = number
description = "CPU utilization percentage threshold for the CloudWatch alarm"
default = 80
validation {
condition = var.cpu_alarm_threshold >= 0 && var.cpu_alarm_threshold <= 100
error_message = "cpu_alarm_threshold must be between 0 and 100 (percentage)."
}
}
variable "cpu_alarm_evaluation_periods" {
type = number
description = "Number of consecutive periods the threshold must be breached before alarming"
default = 3
validation {
condition = var.cpu_alarm_evaluation_periods >= 1
error_message = "cpu_alarm_evaluation_periods must be at least 1."
}
}
variable "cpu_alarm_period" {
type = number
description = "Period in seconds over which the CPU metric is evaluated"
default = 300
validation {
condition = var.cpu_alarm_period >= 60 && var.cpu_alarm_period % 60 == 0
error_message = "cpu_alarm_period must be a multiple of 60 seconds and at least 60 (CloudWatch requirement)."
}
}
variable "alarm_actions" {
type = list(string)
description = "List of ARNs to notify when the alarm transitions state (e.g. SNS topic ARNs)"
default = []
}

View File

@@ -73,17 +73,11 @@ ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
ARG SENTRY_RELEASE
ENV SENTRY_RELEASE=${SENTRY_RELEASE}
# Add NODE_OPTIONS argument
ARG NODE_OPTIONS
# SENTRY_AUTH_TOKEN is injected via BuildKit secret mount so it is never written
# to any image layer, build cache, or registry manifest.
# Use NODE_OPTIONS in the build command
RUN --mount=type=secret,id=sentry_auth_token,env=SENTRY_AUTH_TOKEN \
NODE_OPTIONS="${NODE_OPTIONS}" npx next build
RUN NODE_OPTIONS="${NODE_OPTIONS}" npx next build
# Step 2. Production image, copy all the files and run next
FROM base AS runner
@@ -156,9 +150,6 @@ ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
ARG NEXT_PUBLIC_RECAPTCHA_SITE_KEY
ENV NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${NEXT_PUBLIC_RECAPTCHA_SITE_KEY}
ARG SENTRY_RELEASE
ENV SENTRY_RELEASE=${SENTRY_RELEASE}
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION}

View File

@@ -1 +1,7 @@
export { default } from "@/refresh-pages/admin/CodeInterpreterPage";
"use client";
import CodeInterpreterPage from "@/refresh-pages/admin/CodeInterpreterPage";
export default function Page() {
return <CodeInterpreterPage />;
}

View File

@@ -0,0 +1,23 @@
"use client";
import { ModalCreationInterface } from "@/refresh-components/contexts/ModalContext";
import { ImageProvider } from "@/app/admin/configuration/image-generation/constants";
import { LLMProviderView } from "@/interfaces/llm";
import { ImageGenerationConfigView } from "@/lib/configuration/imageConfigurationService";
import { getImageGenForm } from "./forms";
interface Props {
modal: ModalCreationInterface;
imageProvider: ImageProvider;
existingProviders: LLMProviderView[];
existingConfig?: ImageGenerationConfigView;
onSuccess: () => void;
}
/**
* Modal for creating/editing image generation configurations.
* Routes to provider-specific forms based on imageProvider.provider_name.
*/
export default function ImageGenerationConnectionModal(props: Props) {
return <>{getImageGenForm(props)}</>;
}

View File

@@ -2,6 +2,7 @@
import { useState, useMemo, useEffect } from "react";
import useSWR from "swr";
import { Select } from "@/refresh-components/cards";
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
@@ -10,39 +11,24 @@ import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
import {
IMAGE_PROVIDER_GROUPS,
ImageProvider,
} from "@/refresh-pages/admin/ImageGenerationPage/constants";
} from "@/app/admin/configuration/image-generation/constants";
import ImageGenerationConnectionModal from "@/app/admin/configuration/image-generation/ImageGenerationConnectionModal";
import {
ImageGenerationConfigView,
setDefaultImageGenerationConfig,
unsetDefaultImageGenerationConfig,
deleteImageGenerationConfig,
} from "@/refresh-pages/admin/ImageGenerationPage/svc";
} from "@/lib/configuration/imageConfigurationService";
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
import Message from "@/refresh-components/messages/Message";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { Button, SelectCard, Text } from "@opal/components";
import { Content, CardHeaderLayout } from "@opal/layouts";
import { Hoverable } from "@opal/core";
import {
SvgArrowExchange,
SvgArrowRightCircle,
SvgCheckSquare,
SvgSettings,
SvgSlash,
SvgUnplug,
} from "@opal/icons";
import { Button, Text } from "@opal/components";
import { SvgSlash, SvgUnplug } from "@opal/icons";
import { markdown } from "@opal/utils";
import { getImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms";
const NO_DEFAULT_VALUE = "__none__";
const STATUS_TO_STATE = {
disconnected: "empty",
connected: "filled",
selected: "selected",
} as const;
export default function ImageGenerationContent() {
const {
data: llmProviderResponse,
@@ -212,13 +198,16 @@ export default function ImageGenerationContent() {
return (
<>
<div className="flex flex-col gap-4">
<Content
title="Image Generation Model"
description="Select a model to generate images in chat."
sizePreset="main-content"
variant="section"
/>
<div className="flex flex-col gap-6">
{/* Section Header */}
<div className="flex flex-col gap-0.5">
<Text font="main-content-emphasis" color="text-05">
Image Generation Model
</Text>
<Text font="secondary-body" color="text-03">
Select a model to generate images in chat.
</Text>
</div>
{connectedProviderIds.size === 0 && (
<Message
@@ -234,111 +223,32 @@ export default function ImageGenerationContent() {
{/* Provider Groups */}
{IMAGE_PROVIDER_GROUPS.map((group) => (
<div key={group.name} className="flex flex-col gap-2">
<Content title={group.name} sizePreset="secondary" variant="body" />
{group.providers.map((provider) => {
const status = getStatus(provider);
const isDisconnected = status === "disconnected";
const isConnected = status === "connected";
const isSelected = status === "selected";
return (
<Hoverable.Root
<Text font="secondary-body" color="text-03">
{group.name}
</Text>
<div className="flex flex-col gap-2">
{group.providers.map((provider) => (
<Select
key={provider.image_provider_id}
group="image-gen/ProviderCard"
>
<SelectCard
variant="select-card"
state={STATUS_TO_STATE[status]}
sizeVariant="lg"
aria-label={`image-gen-provider-${provider.image_provider_id}`}
onClick={
isDisconnected
? () => handleConnect(provider)
: isSelected
? () => handleDeselect(provider)
: undefined
}
>
<CardHeaderLayout
sizePreset="main-ui"
variant="section"
icon={() => (
<ProviderIcon
provider={provider.provider_name}
size={16}
/>
)}
title={provider.title}
description={provider.description}
rightChildren={
isDisconnected ? (
<Button
prominence="tertiary"
rightIcon={SvgArrowExchange}
onClick={(e) => {
e.stopPropagation();
handleConnect(provider);
}}
>
Connect
</Button>
) : isConnected ? (
<Button
prominence="tertiary"
rightIcon={SvgArrowRightCircle}
onClick={(e) => {
e.stopPropagation();
handleSelect(provider);
}}
>
Set as Default
</Button>
) : isSelected ? (
<div className="p-2">
<Content
title="Current Default"
sizePreset="main-ui"
variant="section"
icon={SvgCheckSquare}
/>
</div>
) : undefined
}
bottomRightChildren={
!isDisconnected ? (
<div className="flex flex-row px-1 pb-1">
<Hoverable.Item group="image-gen/ProviderCard">
<Button
icon={SvgUnplug}
tooltip="Disconnect"
aria-label={`Disconnect ${provider.title}`}
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
setDisconnectProvider(provider);
}}
size="md"
/>
</Hoverable.Item>
<Button
icon={SvgSettings}
tooltip="Edit"
aria-label={`Edit ${provider.title}`}
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
handleEdit(provider);
}}
size="md"
/>
</div>
) : undefined
}
/>
</SelectCard>
</Hoverable.Root>
);
})}
aria-label={`image-gen-provider-${provider.image_provider_id}`}
icon={() => (
<ProviderIcon provider={provider.provider_name} size={18} />
)}
title={provider.title}
description={provider.description}
status={getStatus(provider)}
onConnect={() => handleConnect(provider)}
onSelect={() => handleSelect(provider)}
onDeselect={() => handleDeselect(provider)}
onEdit={() => handleEdit(provider)}
onDisconnect={
getStatus(provider) !== "disconnected"
? () => setDisconnectProvider(provider)
: undefined
}
/>
))}
</div>
</div>
))}
</div>
@@ -447,13 +357,13 @@ export default function ImageGenerationContent() {
{activeProvider && (
<modal.Provider>
{getImageGenForm({
modal: modal,
imageProvider: activeProvider,
existingProviders: llmProviders,
existingConfig: editConfig || undefined,
onSuccess: handleModalSuccess,
})}
<ImageGenerationConnectionModal
modal={modal}
imageProvider={activeProvider}
existingProviders={llmProviders}
existingConfig={editConfig || undefined}
onSuccess={handleModalSuccess}
/>
</modal.Provider>
)}
</>

View File

@@ -7,14 +7,14 @@ import { FormField } from "@/refresh-components/form/FormField";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
import { ImageGenFormWrapper } from "@/refresh-pages/admin/ImageGenerationPage/forms/ImageGenFormWrapper";
import { ImageGenFormWrapper } from "./ImageGenFormWrapper";
import {
ImageGenFormBaseProps,
ImageGenFormChildProps,
ImageGenSubmitPayload,
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
import { ImageGenerationCredentials } from "@/refresh-pages/admin/ImageGenerationPage/svc";
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
} from "./types";
import { ImageGenerationCredentials } from "@/lib/configuration/imageConfigurationService";
import { ImageProvider } from "../constants";
import {
parseAzureTargetUri,
isValidAzureTargetUri,

View File

@@ -10,14 +10,14 @@ import {
createImageGenerationConfig,
updateImageGenerationConfig,
fetchImageGenerationCredentials,
} from "@/refresh-pages/admin/ImageGenerationPage/svc";
} from "@/lib/configuration/imageConfigurationService";
import { APIFormFieldState } from "@/refresh-components/form/types";
import {
ImageGenFormWrapperProps,
ImageGenFormChildProps,
ImageGenSubmitPayload,
FormValues,
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
} from "./types";
import { toast } from "@/hooks/useToast";
export function ImageGenFormWrapper<T extends FormValues>({
@@ -41,6 +41,9 @@ export function ImageGenFormWrapper<T extends FormValues>({
const [errorMessage, setErrorMessage] = useState("");
const [isLoadingCredentials, setIsLoadingCredentials] = useState(false);
// Form reset key for re-initialization
const [formResetKey, setFormResetKey] = useState(0);
// Track merged initial values with fetched credentials
const [mergedInitialValues, setMergedInitialValues] =
useState<T>(initialValues);
@@ -70,6 +73,7 @@ export function ImageGenFormWrapper<T extends FormValues>({
imageProvider
);
setMergedInitialValues((prev) => ({ ...prev, ...credValues }));
setFormResetKey((k) => k + 1);
}
})
.catch((err) => {
@@ -272,6 +276,7 @@ export function ImageGenFormWrapper<T extends FormValues>({
return (
<Formik<T>
key={formResetKey}
initialValues={mergedInitialValues}
onSubmit={handleSubmit}
validationSchema={validationSchema}

View File

@@ -6,14 +6,14 @@ import { FormikField } from "@/refresh-components/form/FormikField";
import { FormField } from "@/refresh-components/form/FormField";
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
import { ImageGenFormWrapper } from "@/refresh-pages/admin/ImageGenerationPage/forms/ImageGenFormWrapper";
import { ImageGenFormWrapper } from "./ImageGenFormWrapper";
import {
ImageGenFormBaseProps,
ImageGenFormChildProps,
ImageGenSubmitPayload,
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
import { ImageGenerationCredentials } from "@/refresh-pages/admin/ImageGenerationPage/svc";
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
} from "./types";
import { ImageGenerationCredentials } from "@/lib/configuration/imageConfigurationService";
import { ImageProvider } from "../constants";
// OpenAI form values - just API key
interface OpenAIFormValues {

View File

@@ -6,14 +6,14 @@ import { FormField } from "@/refresh-components/form/FormField";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputFile from "@/refresh-components/inputs/InputFile";
import InlineExternalLink from "@/refresh-components/InlineExternalLink";
import { ImageGenFormWrapper } from "@/refresh-pages/admin/ImageGenerationPage/forms/ImageGenFormWrapper";
import { ImageGenFormWrapper } from "./ImageGenFormWrapper";
import {
ImageGenFormBaseProps,
ImageGenFormChildProps,
ImageGenSubmitPayload,
} from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
import { ImageGenerationCredentials } from "@/refresh-pages/admin/ImageGenerationPage/svc";
} from "./types";
import { ImageProvider } from "../constants";
import { ImageGenerationCredentials } from "@/lib/configuration/imageConfigurationService";
const VERTEXAI_PROVIDER_NAME = "vertex_ai";
const VERTEXAI_DEFAULT_LOCATION = "global";

View File

@@ -1,8 +1,8 @@
import React from "react";
import { ImageGenFormBaseProps } from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
import { OpenAIImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/OpenAIImageGenForm";
import { AzureImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/AzureImageGenForm";
import { VertexImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/VertexImageGenForm";
import { ImageGenFormBaseProps } from "./types";
import { OpenAIImageGenForm } from "./OpenAIImageGenForm";
import { AzureImageGenForm } from "./AzureImageGenForm";
import { VertexImageGenForm } from "./VertexImageGenForm";
/**
* Factory function that routes to the correct provider-specific form

View File

@@ -0,0 +1,5 @@
export * from "./types";
export { ImageGenFormWrapper } from "./ImageGenFormWrapper";
export { OpenAIImageGenForm } from "./OpenAIImageGenForm";
export { AzureImageGenForm } from "./AzureImageGenForm";
export { getImageGenForm } from "./getImageGenForm";

View File

@@ -1,10 +1,10 @@
import { FormikProps } from "formik";
import { ImageProvider } from "@/refresh-pages/admin/ImageGenerationPage/constants";
import { ImageProvider } from "../constants";
import { LLMProviderView } from "@/interfaces/llm";
import {
ImageGenerationConfigView,
ImageGenerationCredentials,
} from "@/refresh-pages/admin/ImageGenerationPage/svc";
} from "@/lib/configuration/imageConfigurationService";
import { ModalCreationInterface } from "@/refresh-components/contexts/ModalContext";
import { APIFormFieldState } from "@/refresh-components/form/types";

View File

@@ -1 +1,22 @@
export { default } from "@/refresh-pages/admin/ImageGenerationPage";
"use client";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import ImageGenerationContent from "./ImageGenerationContent";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
const route = ADMIN_ROUTES.IMAGE_GENERATION;
export default function Page() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={route.icon}
title={route.title}
description="Settings for in-chat image generation."
/>
<SettingsLayouts.Body>
<ImageGenerationContent />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,129 @@
import React from "react";
import type { Meta, StoryObj } from "@storybook/react";
import Select from "./Select";
import { SvgSettings, SvgFolder, SvgSearch } from "@opal/icons";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
const meta: Meta<typeof Select> = {
title: "refresh-components/cards/Select",
component: Select,
tags: ["autodocs"],
decorators: [
(Story) => (
<TooltipPrimitive.Provider>
<div style={{ maxWidth: 500 }}>
<Story />
</div>
</TooltipPrimitive.Provider>
),
],
};
export default meta;
type Story = StoryObj<typeof Select>;
export const Disconnected: Story = {
args: {
icon: SvgFolder,
title: "Google Drive",
description: "Connect to sync your files",
status: "disconnected",
onConnect: () => {},
},
};
export const Connected: Story = {
args: {
icon: SvgFolder,
title: "Google Drive",
description: "Connected and syncing",
status: "connected",
onSelect: () => {},
onEdit: () => {},
},
};
export const Selected: Story = {
args: {
icon: SvgFolder,
title: "Google Drive",
description: "Currently the default source",
status: "selected",
onDeselect: () => {},
onEdit: () => {},
},
};
export const DisabledState: Story = {
args: {
icon: SvgFolder,
title: "Google Drive",
description: "Not available on this plan",
status: "disconnected",
disabled: true,
onConnect: () => {},
},
};
export const MediumSize: Story = {
args: {
icon: SvgSearch,
title: "Elastic Search",
description: "Search engine connector",
status: "connected",
medium: true,
onSelect: () => {},
},
};
export const CustomLabels: Story = {
args: {
icon: SvgSettings,
title: "Custom LLM",
description: "Your custom model endpoint",
status: "connected",
connectLabel: "Link",
selectLabel: "Make Primary",
selectedLabel: "Primary Model",
onSelect: () => {},
onEdit: () => {},
},
};
export const AllStates: Story = {
render: () => (
<div style={{ display: "flex", flexDirection: "column", gap: 12 }}>
<Select
icon={SvgFolder}
title="Google Drive"
description="Connect to sync your files"
status="disconnected"
onConnect={() => {}}
/>
<Select
icon={SvgSearch}
title="Confluence"
description="Connected and syncing"
status="connected"
onSelect={() => {}}
onEdit={() => {}}
/>
<Select
icon={SvgSettings}
title="Notion"
description="Currently the default source"
status="selected"
onDeselect={() => {}}
onEdit={() => {}}
/>
<Select
icon={SvgFolder}
title="Sharepoint"
description="Not available"
status="disconnected"
disabled
onConnect={() => {}}
/>
</div>
),
};

View File

@@ -0,0 +1,229 @@
"use client";
import React, { useState } from "react";
import type { IconProps } from "@opal/types";
import { cn, noProp } from "@/lib/utils";
import { Disabled } from "@opal/core";
import Text from "@/refresh-components/texts/Text";
import { Button } from "@opal/components";
import SelectButton from "@/refresh-components/buttons/SelectButton";
import {
SvgArrowExchange,
SvgArrowRightCircle,
SvgCheckSquare,
SvgSettings,
SvgUnplug,
} from "@opal/icons";
const containerClasses = {
selected: "border-action-link-05 bg-action-link-01",
connected: "border-border-01 bg-background-tint-00 hover:shadow-00",
disconnected: "border-border-01 bg-background-neutral-01 hover:shadow-00",
} as const;
export interface SelectProps
extends Omit<React.ComponentPropsWithoutRef<"div">, "title"> {
// Content
icon: React.FunctionComponent<IconProps>;
title: string;
description: string;
// State
status: "disconnected" | "connected" | "selected";
// Actions
onConnect?: () => void;
onSelect?: () => void;
onDeselect?: () => void;
onEdit?: () => void;
onDisconnect?: () => void;
// Labels (customizable)
connectLabel?: string;
selectLabel?: string;
selectedLabel?: string;
// Size
large?: boolean;
medium?: boolean;
// Optional
className?: string;
disabled?: boolean;
}
export default function Select({
icon: Icon,
title,
description,
status,
onConnect,
onSelect,
onDeselect,
onEdit,
onDisconnect,
connectLabel = "Connect",
selectLabel = "Set as Default",
selectedLabel = "Current Default",
large = true,
medium,
className,
disabled,
...rest
}: SelectProps) {
const sizeClass = medium ? "h-[3.75rem]" : "min-h-[3.75rem] max-h-[5.25rem]";
const containerClass = containerClasses[status];
const [isHovered, setIsHovered] = useState(false);
const isSelected = status === "selected";
const isConnected = status === "connected";
const isDisconnected = status === "disconnected";
const isCardClickable = isDisconnected && onConnect && !disabled;
const handleCardClick = () => {
if (isCardClickable) {
onConnect?.();
}
};
return (
<Disabled disabled={disabled} allowClick>
<div
{...rest}
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
onClick={isCardClickable ? handleCardClick : undefined}
className={cn(
"flex items-start justify-between gap-3 rounded-16 border p-2 min-w-[17.5rem]",
sizeClass,
containerClass,
isCardClickable &&
"cursor-pointer hover:bg-background-tint-01 transition-colors",
className
)}
>
{/* Left section - Icon, Title, Description */}
<div className="flex flex-1 items-start gap-1 p-1">
<div className="flex size-5 items-center justify-center px-0.5 shrink-0">
<Icon
className={cn(
"size-4",
isSelected ? "text-action-text-link-05" : "text-text-02"
)}
/>
</div>
<div className="flex flex-col gap-0.5">
<Text mainUiAction text05>
{title}
</Text>
<Text secondaryBody text03>
{description}
</Text>
</div>
</div>
{/* Right section - Actions */}
<div className="flex flex-col h-full items-end justify-between gap-1">
{/* Disconnected: Show Connect button */}
{isDisconnected && (
<Disabled disabled={disabled || !onConnect}>
<Button
prominence="tertiary"
onClick={noProp(onConnect)}
rightIcon={SvgArrowExchange}
>
{connectLabel}
</Button>
</Disabled>
)}
{/* Connected: Show select icon + settings icon */}
{isConnected && (
<>
<Disabled disabled={disabled || !onSelect}>
<SelectButton
action
folded
transient={isHovered}
onClick={onSelect}
rightIcon={SvgArrowRightCircle}
>
{selectLabel}
</SelectButton>
</Disabled>
<div className="flex px-1 gap-1">
{onDisconnect && (
<Disabled disabled={disabled}>
<Button
icon={SvgUnplug}
tooltip="Disconnect"
prominence="tertiary"
size="sm"
onClick={noProp(onDisconnect)}
aria-label={`Disconnect ${title}`}
/>
</Disabled>
)}
{onEdit && (
<Disabled disabled={disabled}>
<Button
icon={SvgSettings}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={noProp(onEdit)}
aria-label={`Edit ${title}`}
/>
</Disabled>
)}
</div>
</>
)}
{/* Selected: Show "Current Default" label + settings icon */}
{isSelected && (
<>
<Disabled disabled={disabled}>
<SelectButton
action
engaged
onClick={onDeselect}
leftIcon={SvgCheckSquare}
>
{selectedLabel}
</SelectButton>
</Disabled>
<div className="flex px-1 gap-1">
{onDisconnect && (
<Disabled disabled={disabled}>
<Button
icon={SvgUnplug}
tooltip="Disconnect"
prominence="tertiary"
size="sm"
onClick={noProp(onDisconnect)}
aria-label={`Disconnect ${title}`}
/>
</Disabled>
)}
{onEdit && (
<Disabled disabled={disabled}>
<Button
icon={SvgSettings}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={noProp(onEdit)}
aria-label={`Edit ${title}`}
/>
</Disabled>
)}
</div>
</>
)}
</div>
</div>
</Disabled>
);
}

View File

@@ -1,2 +1,4 @@
export { default as Card } from "./Card";
export type { CardProps } from "./Card";
export { default as Select } from "./Select";
export type { SelectProps } from "./Select";

View File

@@ -1,7 +1,8 @@
"use client";
import { useState } from "react";
import React, { useState } from "react";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { Card, type CardProps } from "@/refresh-components/cards";
import {
SvgArrowExchange,
SvgCheckCircle,
@@ -12,21 +13,47 @@ import {
} from "@opal/icons";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import { Section } from "@/layouts/general-layouts";
import { Button, SelectCard } from "@opal/components";
import { CardHeaderLayout } from "@opal/layouts";
import { Disabled, Hoverable } from "@opal/core";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import Text from "@/refresh-components/texts/Text";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import useCodeInterpreter from "@/hooks/useCodeInterpreter";
import { updateCodeInterpreter } from "@/refresh-pages/admin/CodeInterpreterPage/svc";
import { updateCodeInterpreter } from "@/lib/admin/code-interpreter/svc";
import { ContentAction } from "@opal/layouts";
import { toast } from "@/hooks/useToast";
const route = ADMIN_ROUTES.CODE_INTERPRETER;
// ---------------------------------------------------------------------------
// Sub-components
// ---------------------------------------------------------------------------
interface CodeInterpreterCardProps {
variant?: CardProps["variant"];
title: string;
middleText?: string;
strikethrough?: boolean;
rightContent: React.ReactNode;
}
function CodeInterpreterCard({
variant,
title,
middleText,
strikethrough,
rightContent,
}: CodeInterpreterCardProps) {
return (
// TODO (@raunakab): Allow Content to accept strikethrough and middleText
<Card variant={variant} padding={0.5}>
<ContentAction
icon={SvgTerminal}
title={middleText ? `${title} ${middleText}` : title}
description="Built-in Python runtime"
variant="section"
sizePreset="main-ui"
rightChildren={rightContent}
/>
</Card>
);
}
function CheckingStatus() {
return (
@@ -75,9 +102,46 @@ function ConnectionStatus({ healthy, isLoading }: ConnectionStatusProps) {
);
}
// ---------------------------------------------------------------------------
// Page
// ---------------------------------------------------------------------------
interface ActionButtonsProps {
onDisconnect: () => void;
onRefresh: () => void;
disabled?: boolean;
}
function ActionButtons({
onDisconnect,
onRefresh,
disabled,
}: ActionButtonsProps) {
return (
<Section
flexDirection="row"
justifyContent="end"
alignItems="center"
gap={0.25}
padding={0.25}
>
<Disabled disabled={disabled}>
<Button
prominence="tertiary"
size="sm"
icon={SvgUnplug}
onClick={onDisconnect}
tooltip="Disconnect"
/>
</Disabled>
<Disabled disabled={disabled}>
<Button
prominence="tertiary"
size="sm"
icon={SvgRefreshCw}
onClick={onRefresh}
tooltip="Refresh"
/>
</Disabled>
</Section>
);
}
export default function CodeInterpreterPage() {
const { isHealthy, isEnabled, isLoading, refetch } = useCodeInterpreter();
@@ -111,83 +175,49 @@ export default function CodeInterpreterPage() {
<SettingsLayouts.Body>
{isEnabled || isLoading ? (
<Hoverable.Root group="code-interpreter/Card">
<SelectCard variant="select-card" state="filled" sizeVariant="lg">
<CardHeaderLayout
sizePreset="main-ui"
variant="section"
icon={SvgTerminal}
title="Code Interpreter"
description="Built-in Python runtime"
rightChildren={
<ConnectionStatus healthy={isHealthy} isLoading={isLoading} />
}
bottomRightChildren={
<Section
flexDirection="row"
justifyContent="end"
alignItems="center"
gap={0.25}
padding={0.25}
>
<Disabled disabled={isLoading}>
<Hoverable.Item group="code-interpreter/Card">
<Button
prominence="tertiary"
size="sm"
icon={SvgUnplug}
onClick={() => setShowDisconnectModal(true)}
tooltip="Disconnect"
/>
</Hoverable.Item>
</Disabled>
<Disabled disabled={isLoading}>
<Button
prominence="tertiary"
size="sm"
icon={SvgRefreshCw}
onClick={refetch}
tooltip="Refresh"
/>
</Disabled>
</Section>
}
/>
</SelectCard>
</Hoverable.Root>
<CodeInterpreterCard
title="Code Interpreter"
variant={isHealthy ? "primary" : "secondary"}
strikethrough={!isHealthy}
rightContent={
<Section
flexDirection="column"
justifyContent="center"
alignItems="end"
gap={0}
padding={0}
>
<ConnectionStatus healthy={isHealthy} isLoading={isLoading} />
<ActionButtons
onDisconnect={() => setShowDisconnectModal(true)}
onRefresh={refetch}
disabled={isLoading}
/>
</Section>
}
/>
) : (
<SelectCard
variant="select-card"
state="empty"
sizeVariant="lg"
onClick={() => handleToggle(true)}
>
<CardHeaderLayout
sizePreset="main-ui"
variant="section"
icon={SvgTerminal}
title="Code Interpreter (Disconnected)"
description="Built-in Python runtime"
rightChildren={
<Section flexDirection="row" alignItems="center" padding={0.5}>
{isReconnecting ? (
<CheckingStatus />
) : (
<Button
prominence="tertiary"
rightIcon={SvgArrowExchange}
onClick={(e) => {
e.stopPropagation();
handleToggle(true);
}}
>
Reconnect
</Button>
)}
</Section>
}
/>
</SelectCard>
<CodeInterpreterCard
variant="secondary"
title="Code Interpreter"
middleText="(Disconnected)"
strikethrough={true}
rightContent={
<Section flexDirection="row" alignItems="center" padding={0.5}>
{isReconnecting ? (
<CheckingStatus />
) : (
<Button
prominence="tertiary"
rightIcon={SvgArrowExchange}
onClick={() => handleToggle(true)}
>
Reconnect
</Button>
)}
</Section>
}
/>
)}
</SettingsLayouts.Body>

View File

@@ -5,6 +5,7 @@ import { useRouter } from "next/navigation";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { toast } from "@/hooks/useToast";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import HooksContent from "./HooksContent";
@@ -14,15 +15,20 @@ const route = ADMIN_ROUTES.HOOKS;
export default function HooksPage() {
const router = useRouter();
const { settings, settingsLoading } = useSettingsContext();
const isEE = usePaidEnterpriseFeaturesEnabled();
useEffect(() => {
if (!settingsLoading && !settings.hooks_enabled) {
if (settingsLoading) return;
if (!isEE) {
toast.info("Hook Extensions require an Enterprise license.");
router.replace("/");
} else if (!settings.hooks_enabled) {
toast.info("Hook Extensions are not enabled for this deployment.");
router.replace("/");
}
}, [settingsLoading, settings.hooks_enabled, router]);
}, [settingsLoading, isEE, settings.hooks_enabled, router]);
if (settingsLoading || !settings.hooks_enabled) {
if (settingsLoading || !isEE || !settings.hooks_enabled) {
return <SimpleLoader />;
}

View File

@@ -1,5 +0,0 @@
export * from "@/refresh-pages/admin/ImageGenerationPage/forms/types";
export { ImageGenFormWrapper } from "@/refresh-pages/admin/ImageGenerationPage/forms/ImageGenFormWrapper";
export { OpenAIImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/OpenAIImageGenForm";
export { AzureImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/AzureImageGenForm";
export { getImageGenForm } from "@/refresh-pages/admin/ImageGenerationPage/forms/getImageGenForm";

View File

@@ -1,23 +0,0 @@
"use client";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import ImageGenerationContent from "@/refresh-pages/admin/ImageGenerationPage/ImageGenerationContent";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
const route = ADMIN_ROUTES.IMAGE_GENERATION;
export default function ImageGenerationPage() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={route.icon}
title={route.title}
description="Settings for in-chat image generation."
separator
/>
<SettingsLayouts.Body>
<ImageGenerationContent />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -7,7 +7,7 @@ import {
IconProps,
OpenAIIcon,
} from "@/components/icons/icons";
import ProviderCard from "@/sections/cards/ProviderCard";
import { Select } from "@/refresh-components/cards";
import Message from "@/refresh-components/messages/Message";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { FetchError } from "@/lib/fetcher";
@@ -488,7 +488,7 @@ export default function VoiceConfigurationPage() {
const Icon = getProviderIcon(model.providerType);
return (
<ProviderCard
<Select
key={`${mode}-${model.id}`}
aria-label={`voice-${mode}-${model.id}`}
icon={Icon}

File diff suppressed because it is too large Load Diff

View File

@@ -1,27 +0,0 @@
import type { WebSearchProviderType } from "@/refresh-pages/admin/WebSearchPage/searchProviderUtils";
import type { WebContentProviderType } from "@/refresh-pages/admin/WebSearchPage/contentProviderUtils";
export interface WebSearchProviderView {
id: number;
name: string;
provider_type: WebSearchProviderType;
is_active: boolean;
config: Record<string, string> | null;
has_api_key: boolean;
}
export interface WebContentProviderView {
id: number;
name: string;
provider_type: WebContentProviderType;
is_active: boolean;
config: Record<string, string> | null;
has_api_key: boolean;
}
export interface DisconnectTargetState {
id: number;
label: string;
category: "search" | "content";
providerType: string;
}

View File

@@ -1,160 +0,0 @@
import { CONTENT_PROVIDER_DETAILS } from "@/refresh-pages/admin/WebSearchPage/contentProviderUtils";
import type { WebContentProviderView } from "@/refresh-pages/admin/WebSearchPage/interfaces";
async function parseErrorDetail(
res: Response,
fallback: string
): Promise<string> {
try {
const body = await res.json();
return body?.detail ?? fallback;
} catch {
return fallback;
}
}
export async function activateSearchProvider(
providerId: number
): Promise<void> {
const res = await fetch(
`/api/admin/web-search/search-providers/${providerId}/activate`,
{
method: "POST",
headers: { "Content-Type": "application/json" },
}
);
if (!res.ok) {
throw new Error(
await parseErrorDetail(res, "Failed to set provider as default.")
);
}
}
export async function deactivateSearchProvider(
providerId: number
): Promise<void> {
const res = await fetch(
`/api/admin/web-search/search-providers/${providerId}/deactivate`,
{
method: "POST",
headers: { "Content-Type": "application/json" },
}
);
if (!res.ok) {
throw new Error(
await parseErrorDetail(res, "Failed to deactivate provider.")
);
}
}
export async function activateContentProvider(
provider: WebContentProviderView
): Promise<void> {
if (provider.provider_type === "onyx_web_crawler") {
const res = await fetch(
"/api/admin/web-search/content-providers/reset-default",
{
method: "POST",
headers: { "Content-Type": "application/json" },
}
);
if (!res.ok) {
throw new Error(
await parseErrorDetail(res, "Failed to set crawler as default.")
);
}
} else if (provider.id > 0) {
const res = await fetch(
`/api/admin/web-search/content-providers/${provider.id}/activate`,
{
method: "POST",
headers: { "Content-Type": "application/json" },
}
);
if (!res.ok) {
throw new Error(
await parseErrorDetail(res, "Failed to set crawler as default.")
);
}
} else {
const payload = {
id: null,
name:
provider.name ||
CONTENT_PROVIDER_DETAILS[provider.provider_type]?.label ||
provider.provider_type,
provider_type: provider.provider_type,
api_key: null,
api_key_changed: false,
config: provider.config ?? null,
activate: true,
};
const res = await fetch("/api/admin/web-search/content-providers", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(payload),
});
if (!res.ok) {
throw new Error(
await parseErrorDetail(res, "Failed to set crawler as default.")
);
}
}
}
export async function deactivateContentProvider(
providerId: number,
providerType: string
): Promise<void> {
const endpoint =
providerType === "onyx_web_crawler" || providerId < 0
? "/api/admin/web-search/content-providers/reset-default"
: `/api/admin/web-search/content-providers/${providerId}/deactivate`;
const res = await fetch(endpoint, {
method: "POST",
headers: { "Content-Type": "application/json" },
});
if (!res.ok) {
throw new Error(
await parseErrorDetail(res, "Failed to deactivate provider.")
);
}
}
export async function disconnectProvider(
id: number,
category: "search" | "content",
replacementProviderId: string | null
): Promise<void> {
// If a replacement was selected (not "No Default"), activate it first
if (replacementProviderId && replacementProviderId !== "__none__") {
const repId = Number(replacementProviderId);
const activateEndpoint =
category === "search"
? `/api/admin/web-search/search-providers/${repId}/activate`
: `/api/admin/web-search/content-providers/${repId}/activate`;
const activateRes = await fetch(activateEndpoint, {
method: "POST",
headers: { "Content-Type": "application/json" },
});
if (!activateRes.ok) {
throw new Error(
await parseErrorDetail(
activateRes,
"Failed to activate replacement provider."
)
);
}
}
const res = await fetch(`/api/admin/web-search/${category}-providers/${id}`, {
method: "DELETE",
});
if (!res.ok) {
throw new Error(
await parseErrorDetail(res, "Failed to disconnect provider.")
);
}
}

View File

@@ -1,138 +0,0 @@
"use client";
import type { IconFunctionComponent } from "@opal/types";
import { Button, SelectCard } from "@opal/components";
import { Content, CardHeaderLayout } from "@opal/layouts";
import {
SvgArrowExchange,
SvgArrowRightCircle,
SvgCheckSquare,
SvgSettings,
SvgUnplug,
} from "@opal/icons";
type ProviderStatus = "disconnected" | "connected" | "selected";
interface ProviderCardProps {
icon: IconFunctionComponent;
title: string;
description: string;
status: ProviderStatus;
onConnect?: () => void;
onSelect?: () => void;
onDeselect?: () => void;
onEdit?: () => void;
onDisconnect?: () => void;
selectedLabel?: string;
"aria-label"?: string;
}
const STATUS_TO_STATE = {
disconnected: "empty",
connected: "filled",
selected: "selected",
} as const;
export default function ProviderCard({
icon,
title,
description,
status,
onConnect,
onSelect,
onDeselect,
onEdit,
onDisconnect,
selectedLabel = "Current Default",
"aria-label": ariaLabel,
}: ProviderCardProps) {
const isDisconnected = status === "disconnected";
const isConnected = status === "connected";
const isSelected = status === "selected";
return (
<SelectCard
variant="select-card"
state={STATUS_TO_STATE[status]}
sizeVariant="lg"
aria-label={ariaLabel}
onClick={isDisconnected && onConnect ? onConnect : undefined}
>
<CardHeaderLayout
sizePreset="main-ui"
variant="section"
icon={icon}
title={title}
description={description}
rightChildren={
isDisconnected && onConnect ? (
<Button
prominence="tertiary"
rightIcon={SvgArrowExchange}
onClick={(e) => {
e.stopPropagation();
onConnect();
}}
>
Connect
</Button>
) : isConnected && onSelect ? (
<Button
prominence="tertiary"
rightIcon={SvgArrowRightCircle}
onClick={(e) => {
e.stopPropagation();
onSelect();
}}
>
Set as Default
</Button>
) : isSelected ? (
<div className="p-2">
<Content
title={selectedLabel}
sizePreset="main-ui"
variant="section"
icon={SvgCheckSquare}
/>
</div>
) : undefined
}
bottomRightChildren={
!isDisconnected ? (
<div className="flex flex-row px-1 pb-1">
{onDisconnect && (
<Button
icon={SvgUnplug}
tooltip="Disconnect"
aria-label={`Disconnect ${title}`}
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
onDisconnect();
}}
size="md"
/>
)}
{onEdit && (
<Button
icon={SvgSettings}
tooltip="Edit"
aria-label={`Edit ${title}`}
prominence="tertiary"
onClick={(e) => {
e.stopPropagation();
onEdit();
}}
size="md"
/>
)}
</div>
) : undefined
}
/>
</SelectCard>
);
}
export type { ProviderCardProps, ProviderStatus };

View File

@@ -206,7 +206,8 @@ export default function AdminSidebar({ enableCloudSS }: AdminSidebarProps) {
(billingData && hasActiveSubscription(billingData)) ||
licenseData?.has_license
);
const hooksEnabled = settings?.settings.hooks_enabled ?? false;
const hooksEnabled =
enableEnterprise && (settings?.settings.hooks_enabled ?? false);
const allItems = buildItems(
isCurator,

View File

@@ -78,8 +78,7 @@ test.describe("Image Generation Provider Disconnect", () => {
name: "image-gen-disconnect-non-default-before",
});
// Hover to reveal disconnect button, then verify
await card.hover();
// Verify disconnect button exists and is enabled
const disconnectButton = card.getByRole("button", {
name: "Disconnect DALL-E 3",
});
@@ -155,8 +154,7 @@ test.describe("Image Generation Provider Disconnect", () => {
const defaultCard = getProviderCard(page, "openai_gpt_image_1");
await defaultCard.waitFor({ state: "visible", timeout: 10000 });
// Hover to reveal disconnect button
await defaultCard.hover();
// The disconnect button should be visible and enabled
const disconnectButton = defaultCard.getByRole("button", {
name: "Disconnect GPT Image 1",
});
@@ -198,7 +196,6 @@ test.describe("Image Generation Provider Disconnect", () => {
const defaultCard = getProviderCard(page, "openai_gpt_image_1");
await defaultCard.waitFor({ state: "visible", timeout: 10000 });
await defaultCard.hover();
const disconnectButton = defaultCard.getByRole("button", {
name: "Disconnect GPT Image 1",
});

View File

@@ -1,14 +1,85 @@
import { test, expect } from "@playwright/test";
import { test, expect, Page, Locator } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
import {
WEB_SEARCH_URL,
FAKE_SEARCH_PROVIDERS,
FAKE_CONTENT_PROVIDERS,
findProviderCard,
mainContainer,
mockWebSearchApis,
} from "./svc";
const WEB_SEARCH_URL = "/admin/configuration/web-search";
const FAKE_SEARCH_PROVIDERS = {
exa: {
id: 1,
name: "Exa",
provider_type: "exa",
is_active: true,
config: null,
has_api_key: true,
},
brave: {
id: 2,
name: "Brave",
provider_type: "brave",
is_active: false,
config: null,
has_api_key: true,
},
};
const FAKE_CONTENT_PROVIDERS = {
firecrawl: {
id: 10,
name: "Firecrawl",
provider_type: "firecrawl",
is_active: true,
config: { base_url: "https://api.firecrawl.dev/v2/scrape" },
has_api_key: true,
},
exa: {
id: 11,
name: "Exa",
provider_type: "exa",
is_active: false,
config: null,
has_api_key: true,
},
};
function findProviderCard(page: Page, providerLabel: string): Locator {
return page
.locator("div.rounded-16")
.filter({ hasText: providerLabel })
.first();
}
function mainContainer(page: Page): Locator {
return page.locator("[data-main-container]");
}
async function mockWebSearchApis(
page: Page,
searchProviders: (typeof FAKE_SEARCH_PROVIDERS)[keyof typeof FAKE_SEARCH_PROVIDERS][],
contentProviders: (typeof FAKE_CONTENT_PROVIDERS)[keyof typeof FAKE_CONTENT_PROVIDERS][]
) {
await page.route(
"**/api/admin/web-search/search-providers",
async (route) => {
if (route.request().method() === "GET") {
await route.fulfill({ status: 200, json: searchProviders });
} else {
await route.continue();
}
}
);
await page.route(
"**/api/admin/web-search/content-providers",
async (route) => {
if (route.request().method() === "GET") {
await route.fulfill({ status: 200, json: contentProviders });
} else {
await route.continue();
}
}
);
}
test.describe("Web Search Provider Disconnect", () => {
test.beforeEach(async ({ page }) => {
@@ -36,7 +107,6 @@ test.describe("Web Search Provider Disconnect", () => {
name: "web-search-disconnect-non-active-before",
});
await braveCard.hover();
const disconnectButton = braveCard.getByRole("button", {
name: "Disconnect Brave",
});
@@ -109,7 +179,6 @@ test.describe("Web Search Provider Disconnect", () => {
const exaCard = findProviderCard(page, "Exa");
await exaCard.waitFor({ state: "visible", timeout: 10000 });
await exaCard.hover();
const disconnectButton = exaCard.getByRole("button", {
name: "Disconnect Exa",
});
@@ -150,7 +219,6 @@ test.describe("Web Search Provider Disconnect", () => {
const exaCard = findProviderCard(page, "Exa");
await exaCard.waitFor({ state: "visible", timeout: 10000 });
await exaCard.hover();
const disconnectButton = exaCard.getByRole("button", {
name: "Disconnect Exa",
});
@@ -214,7 +282,6 @@ test.describe("Web Search Provider Disconnect", () => {
const firecrawlCard = findProviderCard(page, "Firecrawl");
await firecrawlCard.waitFor({ state: "visible", timeout: 10000 });
await firecrawlCard.hover();
const disconnectButton = firecrawlCard.getByRole("button", {
name: "Disconnect Firecrawl",
});
@@ -283,7 +350,6 @@ test.describe("Web Search Provider Disconnect", () => {
const firecrawlCard = findProviderCard(page, "Firecrawl");
await firecrawlCard.waitFor({ state: "visible", timeout: 10000 });
await firecrawlCard.hover();
const disconnectButton = firecrawlCard.getByRole("button", {
name: "Disconnect Firecrawl",
});

View File

@@ -1,100 +0,0 @@
import type { Page, Locator } from "@playwright/test";
export const WEB_SEARCH_URL = "/admin/configuration/web-search";
export const FAKE_SEARCH_PROVIDERS = {
exa: {
id: 1,
name: "Exa",
provider_type: "exa",
is_active: true,
config: null,
has_api_key: true,
},
brave: {
id: 2,
name: "Brave",
provider_type: "brave",
is_active: false,
config: null,
has_api_key: true,
},
};
export const FAKE_CONTENT_PROVIDERS = {
firecrawl: {
id: 10,
name: "Firecrawl",
provider_type: "firecrawl",
is_active: true,
config: { base_url: "https://api.firecrawl.dev/v2/scrape" },
has_api_key: true,
},
exa: {
id: 11,
name: "Exa",
provider_type: "exa",
is_active: false,
config: null,
has_api_key: true,
},
};
export function findProviderCard(page: Page, providerLabel: string): Locator {
return page
.locator("div.rounded-16")
.filter({ hasText: providerLabel })
.first();
}
export function mainContainer(page: Page): Locator {
return page.locator("[data-main-container]");
}
export async function openProviderModal(
page: Page,
providerLabel: string
): Promise<void> {
const card = findProviderCard(page, providerLabel);
await card.waitFor({ state: "visible", timeout: 10000 });
// First try to find the Connect button
const connectButton = card.getByRole("button", { name: "Connect" });
if (await connectButton.isVisible({ timeout: 1000 }).catch(() => false)) {
await connectButton.click();
return;
}
// If no Connect button, click the Edit icon button to update credentials
const editButton = card.getByRole("button", { name: /^Edit / });
await editButton.waitFor({ state: "visible", timeout: 5000 });
await editButton.click();
}
export async function mockWebSearchApis(
page: Page,
searchProviders: (typeof FAKE_SEARCH_PROVIDERS)[keyof typeof FAKE_SEARCH_PROVIDERS][],
contentProviders: (typeof FAKE_CONTENT_PROVIDERS)[keyof typeof FAKE_CONTENT_PROVIDERS][]
): Promise<void> {
await page.route(
"**/api/admin/web-search/search-providers",
async (route) => {
if (route.request().method() === "GET") {
await route.fulfill({ status: 200, json: searchProviders });
} else {
await route.continue();
}
}
);
await page.route(
"**/api/admin/web-search/content-providers",
async (route) => {
if (route.request().method() === "GET") {
await route.fulfill({ status: 200, json: contentProviders });
} else {
await route.continue();
}
}
);
}

View File

@@ -1,6 +1,40 @@
import { test, expect } from "@playwright/test";
import { test, expect, Page, Locator } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
import { WEB_SEARCH_URL, findProviderCard, openProviderModal } from "./svc";
const WEB_SEARCH_URL = "/admin/configuration/web-search";
// Helper to find a provider card by its label text
async function findProviderCard(
page: Page,
providerLabel: string
): Promise<Locator> {
const card = page
.locator("div.rounded-16")
.filter({ hasText: providerLabel })
.first();
return card;
}
// Helper to open the provider setup modal - clicks Connect if available, otherwise clicks the Edit icon
async function openProviderModal(
page: Page,
providerLabel: string
): Promise<void> {
const card = await findProviderCard(page, providerLabel);
await card.waitFor({ state: "visible", timeout: 10000 });
// First try to find the Connect button
const connectButton = card.getByRole("button", { name: "Connect" });
if (await connectButton.isVisible({ timeout: 1000 }).catch(() => false)) {
await connectButton.click();
return;
}
// If no Connect button, click the Edit icon button to update credentials
const editButton = card.getByRole("button", { name: /^Edit / });
await editButton.waitFor({ state: "visible", timeout: 5000 });
await editButton.click();
}
test.describe("Web Content Provider Configuration", () => {
test.beforeEach(async ({ page }) => {
@@ -64,7 +98,7 @@ test.describe("Web Content Provider Configuration", () => {
await page.waitForLoadState("networkidle");
const firecrawlCard = findProviderCard(page, "Firecrawl");
const firecrawlCard = await findProviderCard(page, "Firecrawl");
await expect(
firecrawlCard.getByRole("button", { name: "Current Crawler" })
).toBeVisible({ timeout: 15000 });
@@ -76,7 +110,7 @@ test.describe("Web Content Provider Configuration", () => {
page,
}) => {
// First, ensure Firecrawl is configured and active
const firecrawlCard = findProviderCard(page, "Firecrawl");
const firecrawlCard = await findProviderCard(page, "Firecrawl");
await firecrawlCard.waitFor({ state: "visible", timeout: 10000 });
const connectButton = firecrawlCard.getByRole("button", {
@@ -116,7 +150,7 @@ test.describe("Web Content Provider Configuration", () => {
}
// Verify Firecrawl is now the current crawler
const updatedFirecrawlCard = findProviderCard(page, "Firecrawl");
const updatedFirecrawlCard = await findProviderCard(page, "Firecrawl");
await expect(
updatedFirecrawlCard.getByRole("button", { name: "Current Crawler" })
).toBeVisible({ timeout: 15000 });
@@ -126,7 +160,7 @@ test.describe("Web Content Provider Configuration", () => {
);
// Switch to Onyx Web Crawler
const onyxCrawlerCard = findProviderCard(page, "Onyx Web Crawler");
const onyxCrawlerCard = await findProviderCard(page, "Onyx Web Crawler");
await onyxCrawlerCard.waitFor({ state: "visible", timeout: 10000 });
const onyxSetDefault = onyxCrawlerCard.getByRole("button", {

View File

@@ -1,6 +1,42 @@
import { test, expect } from "@playwright/test";
import { test, expect, Page, Locator } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
import { WEB_SEARCH_URL, findProviderCard, openProviderModal } from "./svc";
const WEB_SEARCH_URL = "/admin/configuration/web-search";
// Helper to find a provider card by its label text
async function findProviderCard(
page: Page,
providerLabel: string
): Promise<Locator> {
// Find the card containing the provider label - cards are divs with rounded borders
// The label is in a Text component inside the card
const card = page
.locator("div.rounded-16")
.filter({ hasText: providerLabel })
.first();
return card;
}
// Helper to open the provider setup modal - clicks Connect if available, otherwise clicks the Edit icon
async function openProviderModal(
page: Page,
providerLabel: string
): Promise<void> {
const card = await findProviderCard(page, providerLabel);
await card.waitFor({ state: "visible", timeout: 10000 });
// First try to find the Connect button
const connectButton = card.getByRole("button", { name: "Connect" });
if (await connectButton.isVisible({ timeout: 1000 }).catch(() => false)) {
await connectButton.click();
return;
}
// If no Connect button, click the Edit icon button to update credentials
const editButton = card.getByRole("button", { name: /^Edit / });
await editButton.waitFor({ state: "visible", timeout: 5000 });
await editButton.click();
}
test.describe("Web Search Provider Configuration", () => {
test.beforeEach(async ({ page }) => {
@@ -63,7 +99,7 @@ test.describe("Web Search Provider Configuration", () => {
await page.waitForLoadState("networkidle");
// Verify Exa is now the current default - look for "Current Default" button in the Exa card
const exaCard = findProviderCard(page, "Exa");
const exaCard = await findProviderCard(page, "Exa");
await expect(
exaCard.getByRole("button", { name: "Current Default" })
).toBeVisible({ timeout: 15000 });
@@ -129,7 +165,7 @@ test.describe("Web Search Provider Configuration", () => {
await page.waitForLoadState("networkidle");
// Verify Google PSE is now the current default
const googleCard = findProviderCard(page, "Google PSE");
const googleCard = await findProviderCard(page, "Google PSE");
await expect(
googleCard.getByRole("button", { name: "Current Default" })
).toBeVisible({ timeout: 15000 });
@@ -143,7 +179,7 @@ test.describe("Web Search Provider Configuration", () => {
page,
}) => {
// First, configure Google PSE if not already configured
const googleCard = findProviderCard(page, "Google PSE");
const googleCard = await findProviderCard(page, "Google PSE");
await googleCard.waitFor({ state: "visible", timeout: 10000 });
const connectButton = googleCard.getByRole("button", { name: "Connect" });
@@ -178,7 +214,7 @@ test.describe("Web Search Provider Configuration", () => {
);
// Now click the Edit icon button
const updatedGoogleCard = findProviderCard(page, "Google PSE");
const updatedGoogleCard = await findProviderCard(page, "Google PSE");
const editButton = updatedGoogleCard.getByRole("button", {
name: /^Edit /,
});
@@ -220,7 +256,7 @@ test.describe("Web Search Provider Configuration", () => {
await page.waitForLoadState("networkidle");
// Verify Google PSE is still the current default
const finalGoogleCard = findProviderCard(page, "Google PSE");
const finalGoogleCard = await findProviderCard(page, "Google PSE");
await expect(
finalGoogleCard.getByRole("button", { name: "Current Default" })
).toBeVisible({ timeout: 15000 });
@@ -234,7 +270,7 @@ test.describe("Web Search Provider Configuration", () => {
page,
}) => {
// First, configure Google PSE if not already configured
const googleCard = findProviderCard(page, "Google PSE");
const googleCard = await findProviderCard(page, "Google PSE");
await googleCard.waitFor({ state: "visible", timeout: 10000 });
const connectButton = googleCard.getByRole("button", { name: "Connect" });
@@ -269,7 +305,7 @@ test.describe("Web Search Provider Configuration", () => {
);
// Now click the Edit icon button
const updatedGoogleCard = findProviderCard(page, "Google PSE");
const updatedGoogleCard = await findProviderCard(page, "Google PSE");
const editButton = updatedGoogleCard.getByRole("button", {
name: /^Edit /,
});
@@ -347,7 +383,7 @@ test.describe("Web Search Provider Configuration", () => {
await expect(modalDialog).not.toBeVisible({ timeout: 30000 });
await page.waitForLoadState("networkidle");
const braveCard = findProviderCard(page, "Brave");
const braveCard = await findProviderCard(page, "Brave");
await expect(
braveCard.getByRole("button", { name: "Current Default" })
).toBeVisible({ timeout: 15000 });
@@ -367,7 +403,7 @@ test.describe("Web Search Provider Configuration", () => {
test("should switch between configured providers", async ({ page }) => {
// First, configure Exa if needed
const exaCard = findProviderCard(page, "Exa");
const exaCard = await findProviderCard(page, "Exa");
await exaCard.waitFor({ state: "visible", timeout: 10000 });
let connectButton = exaCard.getByRole("button", { name: "Connect" });
@@ -390,7 +426,7 @@ test.describe("Web Search Provider Configuration", () => {
}
// Configure Google PSE if needed
const googleCard = findProviderCard(page, "Google PSE");
const googleCard = await findProviderCard(page, "Google PSE");
await googleCard.waitFor({ state: "visible", timeout: 10000 });
connectButton = googleCard.getByRole("button", { name: "Connect" });