Compare commits

..

5 Commits

Author SHA1 Message Date
Bo-Onyx
f77d5d2d01 feat(hook): integrate query processing hook point 2026-03-20 17:37:01 -07:00
Bo-Onyx
30e7a831a5 feat(hook): Add hook management API (#9513)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 22:53:59 +00:00
Evan Lohn
276261c96d fix: windows installer (#9507) 2026-03-20 22:53:46 +00:00
Bo-Onyx
205f1410e4 chore(hook): Hook executor. (#9467) 2026-03-20 22:47:01 +00:00
Bo-Onyx
a93d154c27 feat(hook): improve on hook point definition (#9522) 2026-03-20 22:20:42 +00:00
62 changed files with 3015 additions and 1116 deletions

View File

@@ -9,12 +9,12 @@ from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.app_configs import HOOK_ENABLED
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.hooks.utils import HOOKS_AVAILABLE
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
@@ -362,7 +362,7 @@ if not MULTI_TENANT:
tasks_to_schedule.extend(beat_task_templates)
if not MULTI_TENANT and HOOK_ENABLED:
if HOOKS_AVAILABLE:
tasks_to_schedule.append(
{
"name": "hook-execution-log-cleanup",

View File

@@ -59,6 +59,7 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -68,11 +69,18 @@ from onyx.db.models import UserFile
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingPayload
from onyx.hooks.points.query_processing import QueryProcessingResponse
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
@@ -424,6 +432,32 @@ def determine_search_params(
)
def _apply_query_processing_hook(
hook_result: BaseModel | HookSkipped | HookSoftFailed,
message_text: str,
) -> str:
"""Apply the Query Processing hook result to the message text.
Returns the (possibly rewritten) message text, or raises OnyxError with
QUERY_REJECTED if the hook signals rejection (query is null or empty).
HookSkipped and HookSoftFailed are pass-throughs — the original text is
returned unchanged.
"""
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
return message_text
if not isinstance(hook_result, QueryProcessingResponse):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"Expected QueryProcessingResponse from hook, got {type(hook_result).__name__}",
)
if not hook_result.query:
raise OnyxError(
OnyxErrorCode.QUERY_REJECTED,
hook_result.rejection_message or "Your query was rejected.",
)
return hook_result.query
def handle_stream_message_objects(
new_msg_req: SendMessageRequest,
user: User,
@@ -484,6 +518,19 @@ def handle_stream_message_objects(
persona = chat_session.persona
message_text = new_msg_req.message
# Query Processing hook — runs before the message is saved to DB.
hook_result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=QueryProcessingPayload(
query=message_text,
user_email=None if user.is_anonymous else user.email,
chat_session_id=str(chat_session.id),
).model_dump(),
)
message_text = _apply_query_processing_hook(hook_result, message_text)
user_identity = LLMUserIdentity(
user_id=llm_user_identifier, session_id=str(chat_session.id)
)

View File

@@ -2,6 +2,7 @@ import time
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.db.connector_credential_pair import get_connector_credential_pairs
@@ -149,6 +150,9 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
Returns None if search settings did not change, or the old search settings if they
did change.
"""
if DISABLE_VECTOR_DB:
return None
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)

View File

@@ -44,6 +44,7 @@ class OnyxErrorCode(Enum):
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
INVALID_INPUT = ("INVALID_INPUT", 400)
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
QUERY_REJECTED = ("QUERY_REJECTED", 400)
# ------------------------------------------------------------------
# Not Found (404)
@@ -88,6 +89,7 @@ class OnyxErrorCode(Enum):
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
BAD_GATEWAY = ("BAD_GATEWAY", 502)
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502)
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
def __init__(self, code: str, status_code: int) -> None:

View File

@@ -0,0 +1,381 @@
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
Usage (Celery tasks and FastAPI handlers):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
)
if isinstance(result, HookSkipped):
# no active hook configured — continue with original behavior
...
elif isinstance(result, HookSoftFailed):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (spec.response_model)
...
is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:
NetworkError (DNS, connection refused) → False (cannot reach the server)
HTTP 401 / 403 → False (api_key revoked or invalid)
TimeoutException → None (server may be slow, skip write)
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
Unknown exception → None (no signal, skip write)
Non-JSON / non-dict response → None (server responded, skip write)
Success (2xx, valid dict) → True (confirmed reachable)
None means "leave the current value unchanged" — no DB round-trip is made.
DB session design
-----------------
The executor uses three sessions:
1. Caller's session (db_session) — used only for the hook lookup read. All
needed fields are extracted from the Hook object before the HTTP call, so
the caller's session is not held open during the external HTTP request.
2. Log session — a separate short-lived session opened after the HTTP call
completes to write the HookExecutionLog row on failure. Success runs are
not recorded. Committed independently of everything else.
3. Reachable session — a second short-lived session to update is_reachable on
the Hook. Kept separate from the log session so a concurrent hook deletion
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
prevent the execution log from being written. This update is best-effort.
"""
import json
import time
from typing import Any
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.registry import get_hook_point_spec
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
logger = setup_logger()
class HookSkipped:
"""No active hook configured for this hook point."""
class HookSoftFailed:
"""Hook was called but failed with SOFT fail strategy — continuing."""
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
class _HttpOutcome(BaseModel):
"""Structured result of an HTTP hook call, returned by _process_response."""
is_success: bool
updated_is_reachable: (
bool | None
) # True/False = write to DB, None = unchanged (skip write)
status_code: int | None
error_message: str | None
response_payload: dict[str, Any] | None
def _lookup_hook(
db_session: Session,
hook_point: HookPoint,
) -> Hook | HookSkipped:
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
No HTTP call is made and no DB writes are performed for any HookSkipped path.
There is nothing to log and no reachability information to update.
"""
if not HOOKS_AVAILABLE:
return HookSkipped()
hook = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if hook is None or not hook.is_active:
return HookSkipped()
if not hook.endpoint_url:
return HookSkipped()
return hook
def _process_response(
*,
response: httpx.Response | None,
exc: Exception | None,
timeout: float,
) -> _HttpOutcome:
"""Process the result of an HTTP call and return a structured outcome.
Called after the client.post() try/except. If post() raised, exc is set and
response is None. Otherwise response is set and exc is None. Handles
raise_for_status(), JSON decoding, and the dict shape check.
"""
if exc is not None:
if isinstance(exc, httpx.NetworkError):
msg = f"Hook network error (endpoint unreachable): {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False,
status_code=None,
error_message=msg,
response_payload=None,
)
if isinstance(exc, httpx.TimeoutException):
msg = f"Hook timed out after {timeout}s: {exc}"
logger.warning(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # timeout doesn't indicate unreachability
status_code=None,
error_message=msg,
response_payload=None,
)
msg = f"Hook call failed: {exc}"
logger.exception(msg, exc_info=exc)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # unknown error — don't make assumptions
status_code=None,
error_message=msg,
response_payload=None,
)
if response is None:
raise ValueError(
"exactly one of response or exc must be non-None; both are None"
)
status_code = response.status_code
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
logger.warning(msg, exc_info=e)
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
# so the operator knows to update it. All other HTTP errors keep is_reachable
# as-is (server is up, the request just failed for application reasons).
auth_failed = e.response.status_code in (401, 403)
return _HttpOutcome(
is_success=False,
updated_is_reachable=False if auth_failed else None,
status_code=status_code,
error_message=msg,
response_payload=None,
)
try:
response_payload = response.json()
except (json.JSONDecodeError, httpx.DecodingError) as e:
msg = f"Hook returned non-JSON response: {e}"
logger.warning(msg, exc_info=e)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
if not isinstance(response_payload, dict):
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
logger.warning(msg)
return _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=status_code,
error_message=msg,
response_payload=None,
)
return _HttpOutcome(
is_success=True,
updated_is_reachable=True,
status_code=status_code,
error_message=None,
response_payload=response_payload,
)
def _persist_result(
*,
hook_id: int,
outcome: _HttpOutcome,
duration_ms: int,
) -> None:
"""Write the execution log on failure and optionally update is_reachable, each
in its own session so a failure in one does not affect the other."""
# Only write the execution log on failure — success runs are not recorded.
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
# deleted between the initial lookup and here).
if not outcome.is_success:
try:
with get_session_with_current_tenant() as log_session:
create_hook_execution_log__no_commit(
db_session=log_session,
hook_id=hook_id,
is_success=False,
error_message=outcome.error_message,
status_code=outcome.status_code,
duration_ms=duration_ms,
)
log_session.commit()
except Exception:
logger.exception(
f"Failed to persist hook execution log for hook_id={hook_id}"
)
# Update is_reachable separately — best-effort, non-critical.
# None means the value is unchanged (set by the caller to skip the no-op write).
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
# concurrently deleted, so keep this isolated from the log write above.
if outcome.updated_is_reachable is not None:
try:
with get_session_with_current_tenant() as reachable_session:
update_hook__no_commit(
db_session=reachable_session,
hook_id=hook_id,
is_reachable=outcome.updated_is_reachable,
)
reachable_session.commit()
except Exception:
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
) -> BaseModel | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
"active hooks without an endpoint_url must be rejected by _lookup_hook"
)
start = time.monotonic()
response: httpx.Response | None = None
exc: Exception | None = None
try:
api_key: str | None = (
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
)
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(timeout=timeout, follow_redirects=False) as client:
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against the spec's response_model.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: BaseModel | None = None
if outcome.is_success and outcome.response_payload is not None:
spec = get_hook_point_spec(hook.hook_point)
try:
validated_model = spec.response_model.model_validate(
outcome.response_payload
)
except ValidationError as e:
msg = f"Hook response failed validation against {spec.response_model.__name__}: {e}"
logger.warning(msg)
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
outcome = outcome.model_copy(update={"updated_is_reachable": None})
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
if not outcome.is_success:
if fail_strategy == HookFailStrategy.HARD:
raise OnyxError(
OnyxErrorCode.HOOK_EXECUTION_FAILED,
outcome.error_message or "Hook execution failed.",
)
logger.warning(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
raise ValueError(
f"validated_model is None for successful hook call (hook_id={hook_id})"
)
return validated_model
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
) -> BaseModel | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
try:
return _execute_hook_inner(hook, payload)
except Exception:
if hook.fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook.id}"
)
return HookSoftFailed()
raise

View File

@@ -42,12 +42,8 @@ class HookUpdateRequest(BaseModel):
name: str | None = None
endpoint_url: str | None = None
api_key: NonEmptySecretStr | None = None
fail_strategy: HookFailStrategy | None = (
None # if None in model_fields_set, reset to spec default
)
timeout_seconds: float | None = Field(
default=None, gt=0
) # if None in model_fields_set, reset to spec default
fail_strategy: HookFailStrategy | None = None
timeout_seconds: float | None = Field(default=None, gt=0)
@model_validator(mode="after")
def require_at_least_one_field(self) -> "HookUpdateRequest":
@@ -60,6 +56,14 @@ class HookUpdateRequest(BaseModel):
and not (self.endpoint_url or "").strip()
):
raise ValueError("endpoint_url cannot be cleared.")
if "fail_strategy" in self.model_fields_set and self.fail_strategy is None:
raise ValueError(
"fail_strategy cannot be null; omit the field to leave it unchanged."
)
if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None:
raise ValueError(
"timeout_seconds cannot be null; omit the field to leave it unchanged."
)
return self
@@ -90,38 +94,28 @@ class HookResponse(BaseModel):
fail_strategy: HookFailStrategy
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
is_active: bool
is_reachable: bool | None
creator_email: str | None
created_at: datetime
updated_at: datetime
class HookValidateStatus(str, Enum):
passed = "passed" # server responded (any status except 401/403)
auth_failed = "auth_failed" # server responded with 401 or 403
timeout = (
"timeout" # TCP connected, but read/write timed out (server exists but slow)
)
cannot_connect = "cannot_connect" # could not connect to the server
class HookValidateResponse(BaseModel):
success: bool
status: HookValidateStatus
error_message: str | None = None
# ---------------------------------------------------------------------------
# Health models
# ---------------------------------------------------------------------------
class HookHealthStatus(str, Enum):
healthy = "healthy" # green — reachable, no failures in last 1h
degraded = "degraded" # yellow — reachable, failures in last 1h
unreachable = "unreachable" # red — is_reachable=false or null
class HookFailureRecord(BaseModel):
class HookExecutionRecord(BaseModel):
error_message: str | None = None
status_code: int | None = None
duration_ms: int | None = None
created_at: datetime
class HookHealthResponse(BaseModel):
status: HookHealthStatus
recent_failures: list[HookFailureRecord] = Field(
default_factory=list,
description="Last 10 failures, newest first",
max_length=10,
)

View File

@@ -1,6 +1,7 @@
from abc import ABC
from abc import abstractmethod
from typing import Any
from typing import ClassVar
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
@@ -13,22 +14,25 @@ _REQUIRED_ATTRS = (
"default_timeout_seconds",
"fail_hard_description",
"default_fail_strategy",
"payload_model",
"response_model",
)
class HookPointSpec(ABC):
class HookPointSpec:
"""Static metadata and contract for a pipeline hook point.
This is NOT a regular class meant for direct instantiation by callers.
Each concrete subclass represents exactly one hook point and is instantiated
once at startup, registered in onyx.hooks.registry._REGISTRY. No caller
should ever create instances directly — use get_hook_point_spec() or
get_all_specs() from the registry instead.
once at startup, registered in onyx.hooks.registry._REGISTRY. Prefer
get_hook_point_spec() or get_all_specs() from the registry over direct
instantiation.
Each hook point is a concrete subclass of this class. Onyx engineers
own these definitions — customers never touch this code.
Subclasses must define all attributes as class-level constants.
payload_model and response_model must be Pydantic BaseModel subclasses;
input_schema and output_schema are derived from them automatically.
"""
hook_point: HookPoint
@@ -39,21 +43,32 @@ class HookPointSpec(ABC):
default_fail_strategy: HookFailStrategy
docs_url: str | None = None
payload_model: ClassVar[type[BaseModel]]
response_model: ClassVar[type[BaseModel]]
# Computed once at class definition time from payload_model / response_model.
input_schema: ClassVar[dict[str, Any]]
output_schema: ClassVar[dict[str, Any]]
def __init_subclass__(cls, **kwargs: object) -> None:
"""Enforce that every subclass declares all required class attributes.
Called automatically by Python whenever a class inherits from HookPointSpec.
Raises TypeError at import time if any required attribute is missing or if
payload_model / response_model are not Pydantic BaseModel subclasses.
input_schema and output_schema are derived automatically from the models.
"""
super().__init_subclass__(**kwargs)
# Skip intermediate abstract subclasses — they may still be partially defined.
if getattr(cls, "__abstractmethods__", None):
return
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
if missing:
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
@property
@abstractmethod
def input_schema(self) -> dict[str, Any]:
"""JSON schema describing the request payload sent to the customer's endpoint."""
@property
@abstractmethod
def output_schema(self) -> dict[str, Any]:
"""JSON schema describing the expected response from the customer's endpoint."""
for attr in ("payload_model", "response_model"):
val = getattr(cls, attr, None)
if val is None or not (
isinstance(val, type) and issubclass(val, BaseModel)
):
raise TypeError(
f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
)
cls.input_schema = cls.payload_model.model_json_schema()
cls.output_schema = cls.response_model.model_json_schema()

View File

@@ -1,10 +1,19 @@
from typing import Any
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
# TODO(@Bo-Onyx): define payload and response fields
class DocumentIngestionPayload(BaseModel):
pass
class DocumentIngestionResponse(BaseModel):
pass
class DocumentIngestionSpec(HookPointSpec):
"""Hook point that runs during document ingestion.
@@ -18,12 +27,5 @@ class DocumentIngestionSpec(HookPointSpec):
fail_hard_description = "The document will not be indexed."
default_fail_strategy = HookFailStrategy.HARD
@property
def input_schema(self) -> dict[str, Any]:
# TODO(@Bo-Onyx): define input schema
return {"type": "object", "properties": {}}
@property
def output_schema(self) -> dict[str, Any]:
# TODO(@Bo-Onyx): define output schema
return {"type": "object", "properties": {}}
payload_model = DocumentIngestionPayload
response_model = DocumentIngestionResponse

View File

@@ -1,10 +1,39 @@
from typing import Any
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
class QueryProcessingPayload(BaseModel):
model_config = ConfigDict(extra="forbid")
query: str = Field(description="The raw query string exactly as the user typed it.")
user_email: str | None = Field(
description="Email of the user submitting the query, or null if unauthenticated."
)
chat_session_id: str = Field(
description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
)
class QueryProcessingResponse(BaseModel):
# Intentionally permissive — customer endpoints may return extra fields.
query: str | None = Field(
default=None,
description=(
"The query to use in the pipeline. "
"Null, empty string, or absent = reject the query."
),
)
rejection_message: str | None = Field(
default=None,
description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
)
class QueryProcessingSpec(HookPointSpec):
"""Hook point that runs on every user query before it enters the pipeline.
@@ -37,47 +66,5 @@ class QueryProcessingSpec(HookPointSpec):
)
default_fail_strategy = HookFailStrategy.HARD
@property
def input_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The raw query string exactly as the user typed it.",
},
"user_email": {
"type": ["string", "null"],
"description": "Email of the user submitting the query, or null if unauthenticated.",
},
"chat_session_id": {
"type": "string",
"description": "UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires.",
},
},
"required": ["query", "user_email", "chat_session_id"],
"additionalProperties": False,
}
@property
def output_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": ["string", "null"],
"description": (
"The (optionally modified) query to use. "
"Set to null to reject the query."
),
},
"rejection_message": {
"type": ["string", "null"],
"description": (
"Message shown to the user when query is null. "
"Falls back to a generic message if not provided."
),
},
},
"required": ["query"],
}
payload_model = QueryProcessingPayload
response_model = QueryProcessingResponse

View File

@@ -0,0 +1,5 @@
from onyx.configs.app_configs import HOOK_ENABLED
from shared_configs.configs import MULTI_TENANT
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT

View File

@@ -77,6 +77,7 @@ from onyx.server.features.default_assistant.api import (
)
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.hierarchy.api import router as hierarchy_router
from onyx.server.features.hooks.api import router as hook_router
from onyx.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
@@ -453,6 +454,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
register_onyx_exception_handlers(application)
include_router_with_global_prefix_prepended(application, hook_router)
include_router_with_global_prefix_prepended(application, password_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, query_router)

View File

@@ -0,0 +1,453 @@
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.hook import create_hook__no_commit
from onyx.db.hook import delete_hook__no_commit
from onyx.db.hook import get_hook_by_id
from onyx.db.hook import get_hook_execution_logs
from onyx.db.hook import get_hooks
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.api_dependencies import require_hook_enabled
from onyx.hooks.models import HookCreateRequest
from onyx.hooks.models import HookExecutionRecord
from onyx.hooks.models import HookPointMetaResponse
from onyx.hooks.models import HookResponse
from onyx.hooks.models import HookUpdateRequest
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.hooks.registry import get_all_specs
from onyx.hooks.registry import get_hook_point_spec
from onyx.utils.logger import setup_logger
from onyx.utils.url import SSRFException
from onyx.utils.url import validate_outbound_http_url
logger = setup_logger()
# ---------------------------------------------------------------------------
# SSRF protection
# ---------------------------------------------------------------------------
def _check_ssrf_safety(endpoint_url: str) -> None:
"""Raise OnyxError if endpoint_url could be used for SSRF.
Delegates to validate_outbound_http_url with https_only=True.
"""
try:
validate_outbound_http_url(endpoint_url, https_only=True)
except (SSRFException, ValueError) as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookResponse:
return HookResponse(
id=hook.id,
name=hook.name,
hook_point=hook.hook_point,
endpoint_url=hook.endpoint_url,
fail_strategy=hook.fail_strategy,
timeout_seconds=hook.timeout_seconds,
is_active=hook.is_active,
is_reachable=hook.is_reachable,
creator_email=(
creator_email
if creator_email is not None
else (hook.creator.email if hook.creator else None)
),
created_at=hook.created_at,
updated_at=hook.updated_at,
)
def _get_hook_or_404(
db_session: Session,
hook_id: int,
include_creator: bool = False,
) -> Hook:
hook = get_hook_by_id(
db_session=db_session,
hook_id=hook_id,
include_creator=include_creator,
)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook {hook_id} not found.")
return hook
def _raise_for_validation_failure(validation: HookValidateResponse) -> None:
"""Raise an appropriate OnyxError for a non-passed validation result."""
if validation.status == HookValidateStatus.auth_failed:
raise OnyxError(OnyxErrorCode.CREDENTIAL_INVALID, validation.error_message)
if validation.status == HookValidateStatus.timeout:
raise OnyxError(
OnyxErrorCode.GATEWAY_TIMEOUT,
f"Endpoint validation failed: {validation.error_message}",
)
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Endpoint validation failed: {validation.error_message}",
)
def _validate_endpoint(
endpoint_url: str,
api_key: str | None,
timeout_seconds: float,
) -> HookValidateResponse:
"""Check whether endpoint_url is reachable by sending an empty POST request.
We use POST since hook endpoints expect POST requests. The server will typically
respond with 4xx (missing/invalid body) — that is fine. Any HTTP response means
the server is up and routable. A 401/403 response returns auth_failed
(not reachable — indicates the api_key is invalid).
Timeout handling:
- ConnectTimeout: TCP handshake never completed → cannot_connect.
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
(operator should consider increasing timeout_seconds).
- All other exceptions → cannot_connect.
"""
_check_ssrf_safety(endpoint_url)
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try:
with httpx.Client(timeout=timeout_seconds, follow_redirects=False) as client:
response = client.post(endpoint_url, headers=headers)
if response.status_code in (401, 403):
return HookValidateResponse(
status=HookValidateStatus.auth_failed,
error_message=f"Authentication failed (HTTP {response.status_code})",
)
return HookValidateResponse(status=HookValidateStatus.passed)
except httpx.TimeoutException as exc:
# ConnectTimeout: TCP handshake never completed → cannot_connect.
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
if isinstance(exc, httpx.ConnectTimeout):
logger.warning(
"Hook endpoint validation: connect timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
logger.warning(
"Hook endpoint validation: read/write timeout for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.timeout,
error_message="Endpoint timed out — consider increasing timeout_seconds.",
)
except Exception as exc:
logger.warning(
"Hook endpoint validation: connection error for %s",
endpoint_url,
exc_info=exc,
)
return HookValidateResponse(
status=HookValidateStatus.cannot_connect, error_message=str(exc)
)
# ---------------------------------------------------------------------------
# Routers
# ---------------------------------------------------------------------------
router = APIRouter(prefix="/admin/hooks")
# ---------------------------------------------------------------------------
# Hook endpoints
# ---------------------------------------------------------------------------
@router.get("/specs")
def get_hook_point_specs(
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
) -> list[HookPointMetaResponse]:
return [
HookPointMetaResponse(
hook_point=spec.hook_point,
display_name=spec.display_name,
description=spec.description,
docs_url=spec.docs_url,
input_schema=spec.input_schema,
output_schema=spec.output_schema,
default_timeout_seconds=spec.default_timeout_seconds,
default_fail_strategy=spec.default_fail_strategy,
fail_hard_description=spec.fail_hard_description,
)
for spec in get_all_specs()
]
@router.get("")
def list_hooks(
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookResponse]:
hooks = get_hooks(db_session=db_session, include_creator=True)
return [_hook_to_response(h) for h in hooks]
@router.post("")
def create_hook(
req: HookCreateRequest,
user: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Create a new hook. The endpoint is validated before persisting — creation fails if
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
use POST /{hook_id}/activate once ready to receive traffic."""
spec = get_hook_point_spec(req.hook_point)
api_key = req.api_key.get_secret_value() if req.api_key else None
validation = _validate_endpoint(
endpoint_url=req.endpoint_url,
api_key=api_key,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
)
if validation.status != HookValidateStatus.passed:
_raise_for_validation_failure(validation)
hook = create_hook__no_commit(
db_session=db_session,
name=req.name,
hook_point=req.hook_point,
endpoint_url=req.endpoint_url,
api_key=api_key,
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
creator_id=user.id,
)
hook.is_reachable = True
db_session.commit()
return _hook_to_response(hook, creator_email=user.email)
@router.get("/{hook_id}")
def get_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = _get_hook_or_404(db_session, hook_id, include_creator=True)
return _hook_to_response(hook)
@router.patch("/{hook_id}")
def update_hook(
hook_id: int,
req: HookUpdateRequest,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
"""Update hook fields. If endpoint_url, api_key, or timeout_seconds changes, the
endpoint is re-validated using the effective values. For active hooks the update is
rejected on validation failure, keeping live traffic unaffected. For inactive hooks
the update goes through regardless and is_reachable is updated to reflect the result.
Note: if an active hook's endpoint is currently down, even a timeout_seconds-only
increase will be rejected. The recovery flow is: deactivate → update → reactivate.
"""
# api_key: UNSET = no change, None = clear, value = update
api_key: str | None | UnsetType
if "api_key" not in req.model_fields_set:
api_key = UNSET
elif req.api_key is None:
api_key = None
else:
api_key = req.api_key.get_secret_value()
endpoint_url_changing = "endpoint_url" in req.model_fields_set
api_key_changing = not isinstance(api_key, UnsetType)
timeout_changing = "timeout_seconds" in req.model_fields_set
validated_is_reachable: bool | None = None
if endpoint_url_changing or api_key_changing or timeout_changing:
existing = _get_hook_or_404(db_session, hook_id)
effective_url: str = (
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
)
effective_api_key: str | None = (
(api_key if not isinstance(api_key, UnsetType) else None)
if api_key_changing
else (
existing.api_key.get_value(apply_mask=False)
if existing.api_key
else None
)
)
effective_timeout: float = (
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
)
validation = _validate_endpoint(
endpoint_url=effective_url,
api_key=effective_api_key,
timeout_seconds=effective_timeout,
)
if existing.is_active and validation.status != HookValidateStatus.passed:
_raise_for_validation_failure(validation)
validated_is_reachable = validation.status == HookValidateStatus.passed
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
name=req.name,
endpoint_url=(req.endpoint_url if endpoint_url_changing else UNSET),
api_key=api_key,
fail_strategy=req.fail_strategy,
timeout_seconds=req.timeout_seconds,
is_reachable=validated_is_reachable,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
@router.delete("/{hook_id}")
def delete_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> None:
delete_hook__no_commit(db_session=db_session, hook_id=hook_id)
db_session.commit()
@router.post("/{hook_id}/activate")
def activate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = _get_hook_or_404(db_session, hook_id)
if not hook.endpoint_url:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
)
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
validation = _validate_endpoint(
endpoint_url=hook.endpoint_url,
api_key=api_key,
timeout_seconds=hook.timeout_seconds,
)
if validation.status != HookValidateStatus.passed:
# Persist is_reachable=False in a separate session so the request
# session has no commits on the failure path and the transaction
# boundary stays clean.
if hook.is_reachable is not False:
with get_session_with_current_tenant() as side_session:
update_hook__no_commit(
db_session=side_session, hook_id=hook_id, is_reachable=False
)
side_session.commit()
_raise_for_validation_failure(validation)
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
is_active=True,
is_reachable=True,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
@router.post("/{hook_id}/validate")
def validate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookValidateResponse:
hook = _get_hook_or_404(db_session, hook_id)
if not hook.endpoint_url:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
)
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
validation = _validate_endpoint(
endpoint_url=hook.endpoint_url,
api_key=api_key,
timeout_seconds=hook.timeout_seconds,
)
validation_passed = validation.status == HookValidateStatus.passed
if hook.is_reachable != validation_passed:
update_hook__no_commit(
db_session=db_session, hook_id=hook_id, is_reachable=validation_passed
)
db_session.commit()
return validation
@router.post("/{hook_id}/deactivate")
def deactivate_hook(
hook_id: int,
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> HookResponse:
hook = update_hook__no_commit(
db_session=db_session,
hook_id=hook_id,
is_active=False,
include_creator=True,
)
db_session.commit()
return _hook_to_response(hook)
# ---------------------------------------------------------------------------
# Execution log endpoints
# ---------------------------------------------------------------------------
@router.get("/{hook_id}/execution-logs")
def list_hook_execution_logs(
hook_id: int,
limit: int = Query(default=10, ge=1, le=100),
_: User = Depends(current_admin_user),
_hook_enabled: None = Depends(require_hook_enabled),
db_session: Session = Depends(get_session),
) -> list[HookExecutionRecord]:
_get_hook_or_404(db_session, hook_id)
logs = get_hook_execution_logs(db_session=db_session, hook_id=hook_id, limit=limit)
return [
HookExecutionRecord(
error_message=log.error_message,
status_code=log.status_code,
duration_ms=log.duration_ms,
created_at=log.created_at,
)
for log in logs
]

View File

@@ -140,10 +140,20 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
return validated_ip, hostname, port
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
def validate_outbound_http_url(
url: str,
*,
allow_private_network: bool = False,
https_only: bool = False,
) -> str:
"""
Validate a URL that will be used by backend outbound HTTP calls.
Args:
url: The URL to validate.
allow_private_network: If True, skip private/reserved IP checks.
https_only: If True, reject http:// URLs (only https:// is allowed).
Returns:
A normalized URL string with surrounding whitespace removed.
@@ -157,7 +167,12 @@ def validate_outbound_http_url(url: str, *, allow_private_network: bool = False)
parsed = urlparse(normalized_url)
if parsed.scheme not in ("http", "https"):
if https_only:
if parsed.scheme != "https":
raise SSRFException(
f"Invalid URL scheme '{parsed.scheme}'. Only https is allowed."
)
elif parsed.scheme not in ("http", "https"):
raise SSRFException(
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
)

View File

@@ -1,4 +1,12 @@
import pytest
from onyx.chat.process_message import _apply_query_processing_hook
from onyx.chat.process_message import remove_answer_citations
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
def test_remove_answer_citations_strips_http_markdown_citation() -> None:
@@ -32,3 +40,83 @@ def test_remove_answer_citations_preserves_non_citation_markdown_links() -> None
remove_answer_citations(answer)
== "See [reference](https://example.com/Function_(mathematics)) for context."
)
# ---------------------------------------------------------------------------
# Query Processing hook response handling (_apply_query_processing_hook)
# ---------------------------------------------------------------------------
def test_wrong_model_type_raises_internal_error() -> None:
"""If the executor ever returns an unexpected BaseModel type, raise INTERNAL_ERROR
rather than an AssertionError or AttributeError."""
from pydantic import BaseModel as PydanticBaseModel
class _OtherModel(PydanticBaseModel):
pass
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(_OtherModel(), "original query")
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
def test_hook_skipped_leaves_message_text_unchanged() -> None:
result = _apply_query_processing_hook(HookSkipped(), "original query")
assert result == "original query"
def test_hook_soft_failed_leaves_message_text_unchanged() -> None:
result = _apply_query_processing_hook(HookSoftFailed(), "original query")
assert result == "original query"
def test_null_query_raises_query_rejected() -> None:
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=None), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_empty_string_query_raises_query_rejected() -> None:
"""Empty string is falsy — must be treated as rejection, same as None."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=""), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_absent_query_field_raises_query_rejected() -> None:
"""query defaults to None when not provided."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(QueryProcessingResponse(), "original query")
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_rejection_message_surfaced_in_error_when_provided() -> None:
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(
query=None, rejection_message="Queries about X are not allowed."
),
"original query",
)
assert "Queries about X are not allowed." in str(exc_info.value)
def test_fallback_rejection_message_when_none() -> None:
"""No rejection_message → generic fallback used in OnyxError detail."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=None, rejection_message=None),
"original query",
)
assert "Your query was rejected." in str(exc_info.value)
def test_nonempty_query_rewrites_message_text() -> None:
result = _apply_query_processing_hook(
QueryProcessingResponse(query="rewritten query"), "original query"
)
assert result == "rewritten query"

View File

@@ -1,6 +1,5 @@
from typing import Any
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
@@ -11,12 +10,10 @@ def test_init_subclass_raises_for_missing_attrs() -> None:
class IncompleteSpec(HookPointSpec):
hook_point = HookPoint.QUERY_PROCESSING
# missing display_name, description, etc.
# missing display_name, description, payload_model, response_model, etc.
@property
def input_schema(self) -> dict[str, Any]:
return {}
class _Payload(BaseModel):
pass
@property
def output_schema(self) -> dict[str, Any]:
return {}
payload_model = _Payload
response_model = _Payload

View File

@@ -0,0 +1,678 @@
"""Unit tests for the hook executor."""
import json
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
# A valid QueryProcessingResponse payload — used by success-path tests.
_RESPONSE_PAYLOAD: dict[str, Any] = {"query": "better test"}
def _make_hook(
*,
is_active: bool = True,
endpoint_url: str | None = "https://hook.example.com/query",
api_key: MagicMock | None = None,
timeout_seconds: float = 5.0,
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
hook_id: int = 1,
is_reachable: bool | None = None,
hook_point: HookPoint = HookPoint.QUERY_PROCESSING,
) -> MagicMock:
hook = MagicMock()
hook.is_active = is_active
hook.endpoint_url = endpoint_url
hook.api_key = api_key
hook.timeout_seconds = timeout_seconds
hook.id = hook_id
hook.fail_strategy = fail_strategy
hook.is_reachable = is_reachable
hook.hook_point = hook_point
return hook
def _make_api_key(value: str) -> MagicMock:
api_key = MagicMock()
api_key.get_value.return_value = value
return api_key
def _make_response(
*,
status_code: int = 200,
json_return: Any = _RESPONSE_PAYLOAD,
json_side_effect: Exception | None = None,
) -> MagicMock:
"""Build a response mock with controllable json() behaviour."""
response = MagicMock()
response.status_code = status_code
if json_side_effect is not None:
response.json.side_effect = json_side_effect
else:
response.json.return_value = json_return
return response
def _setup_client(
mock_client_cls: MagicMock,
*,
response: MagicMock | None = None,
side_effect: Exception | None = None,
) -> MagicMock:
"""Wire up the httpx.Client mock and return the inner client.
If side_effect is an httpx.HTTPStatusError, it is raised from
raise_for_status() (matching real httpx behaviour) and post() returns a
response mock with the matching status_code set. All other exceptions are
raised directly from post().
"""
mock_client = MagicMock()
if isinstance(side_effect, httpx.HTTPStatusError):
error_response = MagicMock()
error_response.status_code = side_effect.response.status_code
error_response.raise_for_status.side_effect = side_effect
mock_client.post = MagicMock(return_value=error_response)
else:
mock_client.post = MagicMock(
side_effect=side_effect, return_value=response if not side_effect else None
)
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
return mock_client
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def db_session() -> MagicMock:
return MagicMock()
# ---------------------------------------------------------------------------
# Early-exit guards (no HTTP call, no DB writes)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"hooks_available,hook",
[
# HOOKS_AVAILABLE=False exits before the DB lookup — hook is irrelevant.
pytest.param(False, None, id="hooks_not_available"),
pytest.param(True, None, id="hook_not_found"),
pytest.param(True, _make_hook(is_active=False), id="hook_inactive"),
pytest.param(True, _make_hook(endpoint_url=None), id="no_endpoint_url"),
],
)
def test_early_exit_returns_skipped_with_no_db_writes(
db_session: MagicMock,
hooks_available: bool,
hook: MagicMock | None,
) -> None:
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", hooks_available),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
):
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSkipped)
mock_update.assert_not_called()
mock_log.assert_not_called()
# ---------------------------------------------------------------------------
# Successful HTTP call
# ---------------------------------------------------------------------------
def test_success_returns_validated_model_and_sets_reachable(
db_session: MagicMock,
) -> None:
hook = _make_hook()
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
_, update_kwargs = mock_update.call_args
assert update_kwargs["is_reachable"] is True
mock_log.assert_not_called()
def test_success_skips_reachable_write_when_already_true(db_session: MagicMock) -> None:
"""Deduplication guard: a hook already at is_reachable=True that succeeds
must not trigger a DB write."""
hook = _make_hook(is_reachable=True)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, response=_make_response())
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
mock_update.assert_not_called()
def test_non_dict_json_response_is_a_failure(db_session: MagicMock) -> None:
"""response.json() returning a non-dict (e.g. list) must be treated as failure.
The server responded, so is_reachable is not updated."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response(json_return=["unexpected", "list"]),
)
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "non-dict" in (log_kwargs["error_message"] or "")
mock_update.assert_not_called()
def test_json_decode_failure_is_a_failure(db_session: MagicMock) -> None:
"""response.json() raising must be treated as failure with SOFT strategy.
The server responded, so is_reachable is not updated."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response(
json_side_effect=json.JSONDecodeError("not JSON", "", 0)
),
)
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "non-JSON" in (log_kwargs["error_message"] or "")
mock_update.assert_not_called()
# ---------------------------------------------------------------------------
# HTTP failure paths
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"exception,fail_strategy,expected_type,expected_is_reachable",
[
# NetworkError → is_reachable=False
pytest.param(
httpx.ConnectError("refused"),
HookFailStrategy.SOFT,
HookSoftFailed,
False,
id="connect_error_soft",
),
pytest.param(
httpx.ConnectError("refused"),
HookFailStrategy.HARD,
OnyxError,
False,
id="connect_error_hard",
),
# 401/403 → is_reachable=False (api_key revoked)
pytest.param(
httpx.HTTPStatusError(
"401",
request=MagicMock(),
response=MagicMock(status_code=401, text="Unauthorized"),
),
HookFailStrategy.SOFT,
HookSoftFailed,
False,
id="auth_401_soft",
),
pytest.param(
httpx.HTTPStatusError(
"403",
request=MagicMock(),
response=MagicMock(status_code=403, text="Forbidden"),
),
HookFailStrategy.HARD,
OnyxError,
False,
id="auth_403_hard",
),
# TimeoutException → no is_reachable write (None)
pytest.param(
httpx.TimeoutException("timeout"),
HookFailStrategy.SOFT,
HookSoftFailed,
None,
id="timeout_soft",
),
pytest.param(
httpx.TimeoutException("timeout"),
HookFailStrategy.HARD,
OnyxError,
None,
id="timeout_hard",
),
# Other HTTP errors → no is_reachable write (None)
pytest.param(
httpx.HTTPStatusError(
"500",
request=MagicMock(),
response=MagicMock(status_code=500, text="error"),
),
HookFailStrategy.SOFT,
HookSoftFailed,
None,
id="http_status_error_soft",
),
pytest.param(
httpx.HTTPStatusError(
"500",
request=MagicMock(),
response=MagicMock(status_code=500, text="error"),
),
HookFailStrategy.HARD,
OnyxError,
None,
id="http_status_error_hard",
),
],
)
def test_http_failure_paths(
db_session: MagicMock,
exception: Exception,
fail_strategy: HookFailStrategy,
expected_type: type,
expected_is_reachable: bool | None,
) -> None:
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=exception)
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, expected_type)
if expected_is_reachable is None:
mock_update.assert_not_called()
else:
mock_update.assert_called_once()
_, kwargs = mock_update.call_args
assert kwargs["is_reachable"] is expected_is_reachable
# ---------------------------------------------------------------------------
# Authorization header
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"api_key_value,expect_auth_header",
[
pytest.param("secret-token", True, id="api_key_present"),
pytest.param(None, False, id="api_key_absent"),
],
)
def test_authorization_header(
db_session: MagicMock,
api_key_value: str | None,
expect_auth_header: bool,
) -> None:
api_key = _make_api_key(api_key_value) if api_key_value else None
hook = _make_hook(api_key=api_key)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit"),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit"),
patch("httpx.Client") as mock_client_cls,
):
mock_client = _setup_client(mock_client_cls, response=_make_response())
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
_, call_kwargs = mock_client.post.call_args
if expect_auth_header:
assert call_kwargs["headers"]["Authorization"] == f"Bearer {api_key_value}"
else:
assert "Authorization" not in call_kwargs["headers"]
# ---------------------------------------------------------------------------
# Persist session failure
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"http_exception,expect_onyx_error",
[
pytest.param(None, False, id="success_path"),
pytest.param(httpx.ConnectError("refused"), True, id="hard_fail_path"),
],
)
def test_persist_session_failure_is_swallowed(
db_session: MagicMock,
http_exception: Exception | None,
expect_onyx_error: bool,
) -> None:
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor.get_session_with_current_tenant",
side_effect=RuntimeError("DB unavailable"),
),
patch("httpx.Client") as mock_client_cls,
):
_setup_client(
mock_client_cls,
response=_make_response() if not http_exception else None,
side_effect=http_exception,
)
if expect_onyx_error:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
# ---------------------------------------------------------------------------
# Response model validation
# ---------------------------------------------------------------------------
class _StrictResponse(BaseModel):
"""Strict model used to reliably trigger a ValidationError in tests."""
required_field: str # no default → missing key raises ValidationError
def _make_strict_spec() -> MagicMock:
spec = MagicMock()
spec.response_model = _StrictResponse
return spec
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(
HookFailStrategy.SOFT, HookSoftFailed, id="validation_failure_soft"
),
pytest.param(HookFailStrategy.HARD, OnyxError, id="validation_failure_hard"),
],
)
def test_response_validation_failure_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""A response that fails response_model validation is treated like any other
hook failure: logged, is_reachable left unchanged, fail_strategy respected."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch(
"onyx.hooks.executor.get_hook_point_spec",
return_value=_make_strict_spec(),
),
patch("httpx.Client") as mock_client_cls,
):
# Response payload is missing required_field → ValidationError
_setup_client(mock_client_cls, response=_make_response(json_return={}))
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
# is_reachable must not be updated — server responded correctly
mock_update.assert_not_called()
# failure must be logged
mock_log.assert_called_once()
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "validation" in (log_kwargs["error_message"] or "").lower()
# ---------------------------------------------------------------------------
# Outer soft-fail guard in execute_hook
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(HookFailStrategy.SOFT, HookSoftFailed, id="unexpected_exc_soft"),
pytest.param(HookFailStrategy.HARD, ValueError, id="unexpected_exc_hard"),
],
)
def test_unexpected_exception_in_inner_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""An unexpected exception raised by _execute_hook_inner (not an OnyxError from
HARD fail — e.g. a bug or an assertion error) must be swallowed and return
HookSoftFailed for SOFT strategy, or re-raised for HARD strategy."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor._execute_hook_inner",
side_effect=ValueError("unexpected bug"),
),
):
if expected_type is HookSoftFailed:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
else:
with pytest.raises(ValueError, match="unexpected bug"):
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:
"""is_reachable update failing (e.g. concurrent hook deletion) must not
prevent the execution log from being written.
Simulates the production failure path: update_hook__no_commit raises
OnyxError(NOT_FOUND) as it would if the hook was concurrently deleted
between the initial lookup and the reachable update.
"""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch(
"onyx.hooks.executor.update_hook__no_commit",
side_effect=OnyxError(OnyxErrorCode.NOT_FOUND, "hook deleted"),
),
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch("httpx.Client") as mock_client_cls,
):
_setup_client(mock_client_cls, side_effect=httpx.ConnectError("refused"))
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
mock_log.assert_called_once()

View File

@@ -37,18 +37,20 @@ def test_input_schema_query_is_string() -> None:
def test_input_schema_user_email_is_nullable() -> None:
props = QueryProcessingSpec().input_schema["properties"]
assert "null" in props["user_email"]["type"]
# Pydantic v2 emits anyOf for nullable fields
assert any(s.get("type") == "null" for s in props["user_email"]["anyOf"])
def test_output_schema_query_is_required() -> None:
def test_output_schema_query_is_optional() -> None:
# query defaults to None (absent = reject); not required in the schema
schema = QueryProcessingSpec().output_schema
assert "query" in schema["required"]
assert "query" not in schema.get("required", [])
def test_output_schema_query_is_nullable() -> None:
# null means "reject the query"
# null means "reject the query"; Pydantic v2 emits anyOf for nullable fields
props = QueryProcessingSpec().output_schema["properties"]
assert "null" in props["query"]["type"]
assert any(s.get("type") == "null" for s in props["query"]["anyOf"])
def test_output_schema_rejection_message_is_optional() -> None:

View File

@@ -0,0 +1,278 @@
"""Unit tests for onyx.server.features.hooks.api helpers.
Covers:
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
- _validate_endpoint: httpx exception → HookValidateStatus mapping
ConnectTimeout → cannot_connect (TCP handshake never completed)
ConnectError → cannot_connect (DNS / TLS failure)
ReadTimeout et al. → timeout (TCP connected, server slow)
Any other exc → cannot_connect
- _raise_for_validation_failure: HookValidateStatus → OnyxError mapping
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.models import HookValidateResponse
from onyx.hooks.models import HookValidateStatus
from onyx.server.features.hooks.api import _check_ssrf_safety
from onyx.server.features.hooks.api import _raise_for_validation_failure
from onyx.server.features.hooks.api import _validate_endpoint
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_URL = "https://example.com/hook"
_API_KEY = "secret"
_TIMEOUT = 5.0
def _mock_response(status_code: int) -> MagicMock:
response = MagicMock()
response.status_code = status_code
return response
# ---------------------------------------------------------------------------
# _check_ssrf_safety
# ---------------------------------------------------------------------------
class TestCheckSsrfSafety:
def _call(self, url: str) -> None:
_check_ssrf_safety(url)
# --- scheme checks ---
def test_https_is_allowed(self) -> None:
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
self._call("https://example.com/hook") # must not raise
@pytest.mark.parametrize(
"url", ["http://example.com/hook", "ftp://example.com/hook"]
)
def test_non_https_scheme_rejected(self, url: str) -> None:
with pytest.raises(OnyxError) as exc_info:
self._call(url)
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert "https" in (exc_info.value.detail or "").lower()
# --- private IP blocklist ---
@pytest.mark.parametrize(
"ip",
[
pytest.param("127.0.0.1", id="loopback"),
pytest.param("10.0.0.1", id="RFC1918-A"),
pytest.param("172.16.0.1", id="RFC1918-B"),
pytest.param("192.168.1.1", id="RFC1918-C"),
pytest.param("169.254.169.254", id="link-local-IMDS"),
pytest.param("100.64.0.1", id="shared-address-space"),
pytest.param("::1", id="IPv6-loopback"),
pytest.param("fc00::1", id="IPv6-ULA"),
pytest.param("fe80::1", id="IPv6-link-local"),
],
)
def test_private_ip_is_blocked(self, ip: str) -> None:
with (
patch("onyx.utils.url.socket.getaddrinfo") as mock_dns,
pytest.raises(OnyxError) as exc_info,
):
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
self._call("https://internal.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
assert ip in (exc_info.value.detail or "")
def test_public_ip_is_allowed(self) -> None:
with patch("onyx.utils.url.socket.getaddrinfo") as mock_dns:
mock_dns.return_value = [(None, None, None, None, ("93.184.216.34", 0))]
self._call("https://example.com/hook") # must not raise
def test_dns_resolution_failure_raises(self) -> None:
import socket
with (
patch(
"onyx.utils.url.socket.getaddrinfo",
side_effect=socket.gaierror("name not found"),
),
pytest.raises(OnyxError) as exc_info,
):
self._call("https://no-such-host.example.com/hook")
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
# ---------------------------------------------------------------------------
# _validate_endpoint
# ---------------------------------------------------------------------------
class TestValidateEndpoint:
def _call(self, *, api_key: str | None = _API_KEY) -> HookValidateResponse:
# Bypass SSRF check — tested separately in TestCheckSsrfSafety.
with patch("onyx.server.features.hooks.api._check_ssrf_safety"):
return _validate_endpoint(
endpoint_url=_URL,
api_key=api_key,
timeout_seconds=_TIMEOUT,
)
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_2xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(200)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_5xx_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(500)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize("status_code", [401, 403])
def test_401_403_returns_auth_failed(
self, mock_client_cls: MagicMock, status_code: int
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(status_code)
)
result = self._call()
assert result.status == HookValidateStatus.auth_failed
assert str(status_code) in (result.error_message or "")
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_4xx_non_auth_returns_passed(self, mock_client_cls: MagicMock) -> None:
mock_client_cls.return_value.__enter__.return_value.post.return_value = (
_mock_response(422)
)
assert self._call().status == HookValidateStatus.passed
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_timeout_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectTimeout("timed out")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
@pytest.mark.parametrize(
"exc",
[
httpx.ReadTimeout("read timeout"),
httpx.WriteTimeout("write timeout"),
httpx.PoolTimeout("pool timeout"),
],
)
def test_read_write_pool_timeout_returns_timeout(
self, mock_client_cls: MagicMock, exc: httpx.TimeoutException
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = exc
assert self._call().status == HookValidateStatus.timeout
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_connect_error_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
# Covers DNS failures, TLS errors, and other connection-level errors.
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
httpx.ConnectError("name resolution failed")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_arbitrary_exception_returns_cannot_connect(
self, mock_client_cls: MagicMock
) -> None:
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
ConnectionRefusedError("refused")
)
assert self._call().status == HookValidateStatus.cannot_connect
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_api_key_sent_as_bearer(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
self._call(api_key="mykey")
_, kwargs = mock_post.call_args
assert kwargs["headers"]["Authorization"] == "Bearer mykey"
@patch("onyx.server.features.hooks.api.httpx.Client")
def test_no_api_key_omits_auth_header(self, mock_client_cls: MagicMock) -> None:
mock_post = mock_client_cls.return_value.__enter__.return_value.post
mock_post.return_value = _mock_response(200)
self._call(api_key=None)
_, kwargs = mock_post.call_args
assert "Authorization" not in kwargs["headers"]
# ---------------------------------------------------------------------------
# _raise_for_validation_failure
# ---------------------------------------------------------------------------
class TestRaiseForValidationFailure:
@pytest.mark.parametrize(
"status, expected_code",
[
(HookValidateStatus.auth_failed, OnyxErrorCode.CREDENTIAL_INVALID),
(HookValidateStatus.timeout, OnyxErrorCode.GATEWAY_TIMEOUT),
(HookValidateStatus.cannot_connect, OnyxErrorCode.BAD_GATEWAY),
],
)
def test_raises_correct_error_code(
self, status: HookValidateStatus, expected_code: OnyxErrorCode
) -> None:
validation = HookValidateResponse(status=status, error_message="some error")
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.error_code == expected_code
def test_auth_failed_passes_error_message_directly(self) -> None:
validation = HookValidateResponse(
status=HookValidateStatus.auth_failed, error_message="bad credentials"
)
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.detail == "bad credentials"
@pytest.mark.parametrize(
"status", [HookValidateStatus.timeout, HookValidateStatus.cannot_connect]
)
def test_timeout_and_cannot_connect_wrap_error_message(
self, status: HookValidateStatus
) -> None:
validation = HookValidateResponse(status=status, error_message="raw error")
with pytest.raises(OnyxError) as exc_info:
_raise_for_validation_failure(validation)
assert exc_info.value.detail == "Endpoint validation failed: raw error"
# ---------------------------------------------------------------------------
# HookValidateStatus enum string values (API contract)
# ---------------------------------------------------------------------------
class TestHookValidateStatusValues:
@pytest.mark.parametrize(
"status, expected",
[
(HookValidateStatus.passed, "passed"),
(HookValidateStatus.auth_failed, "auth_failed"),
(HookValidateStatus.timeout, "timeout"),
(HookValidateStatus.cannot_connect, "cannot_connect"),
],
)
def test_string_values(self, status: HookValidateStatus, expected: str) -> None:
assert status == expected

View File

@@ -489,20 +489,18 @@ services:
- "${HOST_PORT_80:-80}:80"
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
minio:
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1

View File

@@ -290,25 +290,20 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
# sleep a little bit to allow the web_server / api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod"
env_file:
- .env.nginx
environment:

View File

@@ -314,21 +314,19 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/sslcerts:/etc/nginx/sslcerts
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod.no-letsencrypt"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod.no-letsencrypt"
env_file:
- .env.nginx
environment:

View File

@@ -333,25 +333,20 @@ services:
- "80:80"
- "443:443"
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
# sleep a little bit to allow the web_server / api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.prod"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template.prod"
env_file:
- .env.nginx
environment:

View File

@@ -202,20 +202,18 @@ services:
ports:
- "${NGINX_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
- ../data/nginx:/nginx-templates:ro
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not recieve any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
minio:
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1

View File

@@ -477,7 +477,10 @@ services:
- "${HOST_PORT_80:-80}:80"
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
# Mount templates read-only; the startup command copies them into
# the writable /etc/nginx/conf.d/ inside the container. This avoids
# "Permission denied" errors on Windows Docker bind mounts.
- ../data/nginx:/nginx-templates:ro
# PRODUCTION: Add SSL certificate volumes for HTTPS support:
# - ../data/certbot/conf:/etc/letsencrypt
# - ../data/certbot/www:/var/www/certbot
@@ -489,12 +492,13 @@ services:
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not receive any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
# PRODUCTION: Change to app.conf.template.prod for production nginx config
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
/bin/sh -c "rm -f /etc/nginx/conf.d/default.conf
&& cp -a /nginx-templates/. /etc/nginx/conf.d/
&& sed 's/\r$//' /etc/nginx/conf.d/run-nginx.sh > /tmp/run-nginx.sh
&& chmod +x /tmp/run-nginx.sh
&& /tmp/run-nginx.sh app.conf.template"
cache:
image: redis:7.4-alpine

View File

@@ -2,7 +2,7 @@
# Usage: .\install.ps1 [OPTIONS]
# Remote (with params):
# & ([scriptblock]::Create((irm https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.ps1))) -Lite -NoPrompt
# Remote (defaults only):
# Remote (defaults only, configure via interaction during script):
# irm https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.ps1 | iex
param(
@@ -57,11 +57,7 @@ function Print-Step {
}
function Test-Interactive {
if ($NoPrompt) { return $false }
try {
if ([Console]::IsInputRedirected) { return $false }
return $true
} catch { return [Environment]::UserInteractive }
return -not $NoPrompt
}
function Prompt-OrDefault {
@@ -89,12 +85,12 @@ function Prompt-VersionTag {
Write-Host " - Type a specific tag (e.g., craft-v1.0.0)"
$version = Prompt-OrDefault "Enter tag [default: craft-latest]" "craft-latest"
} else {
Write-Host " - Press Enter for latest (recommended)"
Write-Host " - Press Enter for edge (recommended)"
Write-Host " - Type a specific tag (e.g., v0.1.0)"
$version = Prompt-OrDefault "Enter tag [default: latest]" "latest"
$version = Prompt-OrDefault "Enter tag [default: edge]" "edge"
}
if ($script:IncludeCraftMode -and $version -eq "craft-latest") { Print-Info "Selected: craft-latest (Craft enabled)" }
elseif ($version -eq "latest") { Print-Info "Selected: Latest tag" }
elseif ($version -eq "edge") { Print-Info "Selected: edge (latest nightly)" }
else { Print-Info "Selected: $version" }
return $version
}
@@ -103,16 +99,16 @@ function Prompt-DeploymentMode {
param([string]$LiteOverlayPath)
if ($script:LiteMode) { Print-Info "Deployment mode: Lite (set via -Lite flag)"; return }
Print-Info "Which deployment mode would you like?"
Write-Host " 1) Standard - Full deployment with search, connectors, and RAG"
Write-Host " 2) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
Write-Host " 1) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
Write-Host " LLM chat, tools, file uploads, and Projects still work"
Write-Host " 2) Standard - Full deployment with search, connectors, and RAG"
$modeChoice = Prompt-OrDefault "Choose a mode (1 or 2) [default: 1]" "1"
if ($modeChoice -eq "2") {
Print-Info "Selected: Standard mode"
} else {
$script:LiteMode = $true
Print-Info "Selected: Lite mode"
if (-not (Ensure-OnyxFile $LiteOverlayPath "$($script:GitHubRawUrl)/$($script:LiteComposeFile)" $script:LiteComposeFile)) { exit 1 }
} else {
Print-Info "Selected: Standard mode"
}
}
@@ -358,7 +354,8 @@ function Invoke-OnyxShutdown {
return
}
if (-not (Initialize-ComposeCommand)) { Print-OnyxError "Docker Compose not found."; exit 1 }
$result = Invoke-Compose -AutoDetect stop
$stopArgs = @("stop")
$result = Invoke-Compose -AutoDetect @stopArgs
if ($result -ne 0) { Print-OnyxError "Failed to stop containers"; exit 1 }
Print-Success "Onyx containers stopped (paused)"
}
@@ -375,7 +372,8 @@ function Invoke-OnyxDeleteData {
}
$deployDir = Join-Path $script:InstallRoot "deployment"
if ((Test-Path (Join-Path $deployDir "docker-compose.yml")) -and (Initialize-ComposeCommand)) {
$result = Invoke-Compose -AutoDetect down -v
$downArgs = @("down", "-v")
$result = Invoke-Compose -AutoDetect @downArgs
if ($result -eq 0) { Print-Success "Containers and volumes removed" }
else { Print-OnyxError "Failed to remove containers" }
}
@@ -934,15 +932,6 @@ function Main {
$liteOverlayPath = Join-Path $deploymentDir $script:LiteComposeFile
if ($script:LiteMode) {
if (-not (Ensure-OnyxFile $liteOverlayPath "$($script:GitHubRawUrl)/$($script:LiteComposeFile)" $script:LiteComposeFile)) { exit 1 }
} elseif (Test-Path $liteOverlayPath) {
if (Test-Path (Join-Path $deploymentDir ".env")) {
Print-Warning "Existing lite overlay found but -Lite was not passed."
$reply = Prompt-OrDefault "Remove lite overlay and switch to standard mode? (y/N)" "n"
if ($reply -match '^[Yy]') { Remove-Item -Force $liteOverlayPath; Print-Info "Switched to standard mode" }
else { $script:LiteMode = $true; Print-Info "Keeping lite mode" }
} else {
Remove-Item -Force $liteOverlayPath
}
}
$envTemplateDest = Join-Path $deploymentDir "env.template"
@@ -962,7 +951,8 @@ function Main {
# Check if services are already running
if ((Test-Path $composeDest) -and (Initialize-ComposeCommand)) {
$running = @()
try { $running = @(Invoke-Compose -AutoDetect ps -q 2>$null | Where-Object { $_ }) } catch { }
$psArgs = @("ps", "-q")
try { $running = @(Invoke-Compose -AutoDetect @psArgs 2>$null | Where-Object { $_ }) } catch { }
if ($running.Count -gt 0) {
Print-OnyxError "Onyx services are currently running!"
Print-Info "Run '.\install.ps1 -Shutdown' first, then re-run this script."
@@ -1028,6 +1018,12 @@ function Main {
Print-Info "You can customize .env later for OAuth/SAML, AI models, domain settings, and Craft."
}
# Clean up stale lite overlay if standard mode was selected
if (-not $script:LiteMode -and (Test-Path $liteOverlayPath)) {
Remove-Item -Force $liteOverlayPath
Print-Info "Removed previous lite overlay (switching to standard mode)"
}
# ── Step 6: Check Ports ───────────────────────────────────────────────
Print-Step "Checking for available ports"
$availablePort = Find-AvailablePort 3000
@@ -1037,7 +1033,7 @@ function Main {
Print-Success "Using port $availablePort for nginx"
$currentImageTag = Get-EnvFileValue -Path $envFile -Key "IMAGE_TAG"
$useLatest = ($currentImageTag -eq "latest" -or $currentImageTag -match '^craft-')
$useLatest = ($currentImageTag -eq "edge" -or $currentImageTag -eq "latest" -or $currentImageTag -match '^craft-')
if ($useLatest) { Print-Info "Using '$currentImageTag' tag - will force pull and recreate containers" }
# For pinned version tags, re-download config files from that tag so the
@@ -1069,8 +1065,9 @@ function Main {
# ── Step 8: Start Services ────────────────────────────────────────────
Print-Step "Starting Onyx services"
Print-Info "Launching containers..."
if ($useLatest) { $upResult = Invoke-Compose up -d --pull always --force-recreate }
else { $upResult = Invoke-Compose up -d }
$upArgs = @("up", "-d")
if ($useLatest) { $upArgs += @("--pull", "always", "--force-recreate") }
$upResult = Invoke-Compose @upArgs
if ($upResult -ne 0) { Print-OnyxError "Failed to start Onyx services"; exit 1 }
# ── Step 9: Container Health ──────────────────────────────────────────
@@ -1078,7 +1075,8 @@ function Main {
Start-Sleep -Seconds 10
$restartIssues = $false
$containerIds = @()
try { $containerIds = @(Invoke-Compose ps -q 2>$null | Where-Object { $_ }) } catch { }
$psArgs = @("ps", "-q")
try { $containerIds = @(Invoke-Compose @psArgs 2>$null | Where-Object { $_ }) } catch { }
foreach ($cid in $containerIds) {
if ([string]::IsNullOrWhiteSpace($cid)) { continue }

View File

@@ -96,8 +96,8 @@ fi
# When --lite is passed as a flag, lower resource thresholds early (before the
# resource check). When lite is chosen interactively, the thresholds are adjusted
# inside the new-deployment flow, after the resource check has already passed
# with the standard thresholds — which is the safer direction.
# after the resource check has already passed with the standard thresholds —
# which is the safer direction.
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
@@ -110,9 +110,6 @@ LITE_COMPOSE_FILE="docker-compose.onyx-lite.yml"
# Build the -f flags for docker compose.
# Pass "true" as $1 to auto-detect a previously-downloaded lite overlay
# (used by shutdown/delete-data so users don't need to remember --lite).
# Without the argument, the lite overlay is only included when --lite was
# explicitly passed — preventing install/start from silently staying in
# lite mode just because the file exists on disk from a prior run.
compose_file_args() {
local auto_detect="${1:-false}"
local args="-f docker-compose.yml"
@@ -177,7 +174,7 @@ ensure_file() {
# --- Interactive prompt helpers ---
is_interactive() {
[[ "$NO_PROMPT" = false ]] && [[ -t 0 ]]
[[ "$NO_PROMPT" = false ]]
}
prompt_or_default() {
@@ -745,25 +742,48 @@ if [ "$COMPOSE_VERSION" != "dev" ] && version_compare "$COMPOSE_VERSION" "2.24.0
print_info "Proceeding with installation despite Docker Compose version compatibility issues..."
fi
# Handle lite overlay: ensure it if --lite, clean up stale copies otherwise
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
if [[ "$LITE_MODE" = false ]]; then
print_info "Which deployment mode would you like?"
echo ""
echo " 1) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
echo " LLM chat, tools, file uploads, and Projects still work"
echo " 2) Standard - Full deployment with search, connectors, and RAG"
echo ""
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
echo ""
case "$REPLY" in
2)
print_info "Selected: Standard mode"
;;
*)
LITE_MODE=true
print_info "Selected: Lite mode"
;;
esac
else
print_info "Deployment mode: Lite (set via --lite flag)"
fi
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
print_error "--include-craft cannot be used with Lite mode."
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
exit 1
fi
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
fi
# Handle lite overlay file based on selected mode
if [[ "$LITE_MODE" = true ]]; then
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
elif [[ -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" ]]; then
if [[ -f "${INSTALL_ROOT}/deployment/.env" ]]; then
print_warning "Existing lite overlay found but --lite was not passed."
prompt_yn_or_default "Remove lite overlay and switch to standard mode? (y/N): " "n"
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
print_info "Keeping existing lite overlay. Pass --lite to keep using lite mode."
LITE_MODE=true
else
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed lite overlay (switching to standard mode)"
fi
else
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed previous lite overlay (switching to standard mode)"
fi
rm -f "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}"
print_info "Removed previous lite overlay (switching to standard mode)"
fi
ensure_file "${INSTALL_ROOT}/deployment/env.template" \
@@ -826,22 +846,22 @@ if [ -f "$ENV_FILE" ]; then
if [ "$REPLY" = "update" ]; then
print_info "Update selected. Which tag would you like to deploy?"
echo ""
echo "• Press Enter for latest (recommended)"
echo "• Press Enter for edge (recommended)"
echo "• Type a specific tag (e.g., v0.1.0)"
echo ""
if [ "$INCLUDE_CRAFT" = true ]; then
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
VERSION="$REPLY"
else
prompt_or_default "Enter tag [default: latest]: " "latest"
prompt_or_default "Enter tag [default: edge]: " "edge"
VERSION="$REPLY"
fi
echo ""
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
print_info "Selected: craft-latest (Craft enabled)"
elif [ "$VERSION" = "latest" ]; then
print_info "Selected: Latest version"
elif [ "$VERSION" = "edge" ]; then
print_info "Selected: edge (latest nightly)"
else
print_info "Selected: $VERSION"
fi
@@ -893,45 +913,6 @@ else
print_info "No existing .env file found. Setting up new deployment..."
echo ""
# Ask for deployment mode (standard vs lite) unless already set via --lite flag
if [[ "$LITE_MODE" = false ]]; then
print_info "Which deployment mode would you like?"
echo ""
echo " 1) Standard - Full deployment with search, connectors, and RAG"
echo " 2) Lite - Minimal deployment (no Vespa, Redis, or model servers)"
echo " LLM chat, tools, file uploads, and Projects still work"
echo ""
prompt_or_default "Choose a mode (1 or 2) [default: 1]: " "1"
echo ""
case "$REPLY" in
2)
LITE_MODE=true
print_info "Selected: Lite mode"
ensure_file "${INSTALL_ROOT}/deployment/${LITE_COMPOSE_FILE}" \
"${GITHUB_RAW_URL}/${LITE_COMPOSE_FILE}" "${LITE_COMPOSE_FILE}" || exit 1
;;
*)
print_info "Selected: Standard mode"
;;
esac
else
print_info "Deployment mode: Lite (set via --lite flag)"
fi
# Validate lite + craft combination (could now be set interactively)
if [[ "$LITE_MODE" = true ]] && [[ "$INCLUDE_CRAFT" = true ]]; then
print_error "--include-craft cannot be used with Lite mode."
print_info "Craft requires services (Vespa, Redis, background workers) that lite mode disables."
exit 1
fi
# Adjust resource expectations for lite mode
if [[ "$LITE_MODE" = true ]]; then
EXPECTED_DOCKER_RAM_GB=4
EXPECTED_DISK_GB=16
fi
# Ask for version
print_info "Which tag would you like to deploy?"
echo ""
@@ -942,18 +923,18 @@ else
prompt_or_default "Enter tag [default: craft-latest]: " "craft-latest"
VERSION="$REPLY"
else
echo "• Press Enter for latest (recommended)"
echo "• Press Enter for edge (recommended)"
echo "• Type a specific tag (e.g., v0.1.0)"
echo ""
prompt_or_default "Enter tag [default: latest]: " "latest"
prompt_or_default "Enter tag [default: edge]: " "edge"
VERSION="$REPLY"
fi
echo ""
if [ "$INCLUDE_CRAFT" = true ] && [ "$VERSION" = "craft-latest" ]; then
print_info "Selected: craft-latest (Craft enabled)"
elif [ "$VERSION" = "latest" ]; then
print_info "Selected: Latest tag"
elif [ "$VERSION" = "edge" ]; then
print_info "Selected: edge (latest nightly)"
else
print_info "Selected: $VERSION"
fi
@@ -1111,15 +1092,15 @@ fi
export HOST_PORT=$AVAILABLE_PORT
print_success "Using port $AVAILABLE_PORT for nginx"
# Determine if we're using the latest tag or a craft tag (both should force pull)
# Determine if we're using a floating tag (edge, latest, craft-*) that should force pull
# Read IMAGE_TAG from .env file and remove any quotes or whitespace
CURRENT_IMAGE_TAG=$(grep "^IMAGE_TAG=" "$ENV_FILE" | head -1 | cut -d'=' -f2 | tr -d ' "'"'"'')
if [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
if [ "$CURRENT_IMAGE_TAG" = "edge" ] || [ "$CURRENT_IMAGE_TAG" = "latest" ] || [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
USE_LATEST=true
if [[ "$CURRENT_IMAGE_TAG" == craft-* ]]; then
print_info "Using craft tag '$CURRENT_IMAGE_TAG' - will force pull and recreate containers"
else
print_info "Using 'latest' tag - will force pull and recreate containers"
print_info "Using '$CURRENT_IMAGE_TAG' tag - will force pull and recreate containers"
fi
else
USE_LATEST=false

View File

@@ -1,38 +1,33 @@
"use client";
import { cn } from "@opal/utils";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
interface ActionsContainerProps {
type: "head" | "cell";
children: React.ReactNode;
size?: TableSize;
/** Pass-through click handler (e.g. stopPropagation on body cells). */
onClick?: (e: React.MouseEvent) => void;
children: React.ReactNode;
}
export default function ActionsContainer({
type,
children,
size,
onClick,
}: ActionsContainerProps) {
const size = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
const Tag = type === "head" ? "th" : "td";
return (
<Tag
className="tbl-actions"
data-type={type}
data-size={size}
data-size={resolvedSize}
onClick={onClick}
>
<div
className={cn(
"flex h-full items-center",
type === "cell" ? "justify-end" : "justify-center"
)}
>
{children}
</div>
<div className="flex h-full items-center justify-center">{children}</div>
</Tag>
);
}

View File

@@ -8,7 +8,6 @@ import {
type SortingState,
} from "@tanstack/react-table";
import { Button, LineItemButton } from "@opal/components";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import { SvgArrowUpDown, SvgSortOrder, SvgCheck } from "@opal/icons";
import Popover from "@/refresh-components/Popover";
import Divider from "@/refresh-components/Divider";
@@ -21,6 +20,7 @@ import Text from "@/refresh-components/texts/Text";
interface SortingPopoverProps<TData extends RowData = RowData> {
table: Table<TData>;
sorting: SortingState;
size?: "md" | "lg";
footerText?: string;
ascendingLabel?: string;
descendingLabel?: string;
@@ -29,11 +29,11 @@ interface SortingPopoverProps<TData extends RowData = RowData> {
function SortingPopover<TData extends RowData>({
table,
sorting,
size = "lg",
footerText,
ascendingLabel = "Ascending",
descendingLabel = "Descending",
}: SortingPopoverProps<TData>) {
const size = useTableSize();
const [open, setOpen] = useState(false);
const sortableColumns = table
.getAllLeafColumns()
@@ -158,6 +158,7 @@ function SortingPopover<TData extends RowData>({
// ---------------------------------------------------------------------------
interface CreateSortingColumnOptions {
size?: "md" | "lg";
footerText?: string;
ascendingLabel?: string;
descendingLabel?: string;
@@ -176,6 +177,7 @@ function createSortingColumn<TData>(
<SortingPopover
table={table}
sorting={table.getState().sorting}
size={options?.size}
footerText={options?.footerText}
ascendingLabel={options?.ascendingLabel}
descendingLabel={options?.descendingLabel}

View File

@@ -8,7 +8,6 @@ import {
type VisibilityState,
} from "@tanstack/react-table";
import { Button, LineItemButton, Tag } from "@opal/components";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import { SvgColumn, SvgCheck } from "@opal/icons";
import Popover from "@/refresh-components/Popover";
import Divider from "@/refresh-components/Divider";
@@ -20,13 +19,14 @@ import Divider from "@/refresh-components/Divider";
interface ColumnVisibilityPopoverProps<TData extends RowData = RowData> {
table: Table<TData>;
columnVisibility: VisibilityState;
size?: "md" | "lg";
}
function ColumnVisibilityPopover<TData extends RowData>({
table,
columnVisibility,
size = "lg",
}: ColumnVisibilityPopoverProps<TData>) {
const size = useTableSize();
const [open, setOpen] = useState(false);
// User-defined columns only (exclude internal qualifier/actions)
@@ -87,7 +87,13 @@ function ColumnVisibilityPopover<TData extends RowData>({
// Column definition factory
// ---------------------------------------------------------------------------
function createColumnVisibilityColumn<TData>(): ColumnDef<TData, unknown> {
interface CreateColumnVisibilityColumnOptions {
size?: "md" | "lg";
}
function createColumnVisibilityColumn<TData>(
options?: CreateColumnVisibilityColumnOptions
): ColumnDef<TData, unknown> {
return {
id: "__columnVisibility",
size: 44,
@@ -98,6 +104,7 @@ function createColumnVisibilityColumn<TData>(): ColumnDef<TData, unknown> {
<ColumnVisibilityPopover
table={table}
columnVisibility={table.getState().columnVisibility}
size={options?.size}
/>
),
cell: () => null,

View File

@@ -57,10 +57,9 @@ function DragOverlayRowInner<TData>({
<QualifierContainer key={cell.id} type="cell">
<TableQualifier
content={qualifierColumn.content}
icon={qualifierColumn.getContent?.(row.original)}
initials={qualifierColumn.getInitials?.(row.original)}
icon={qualifierColumn.getIcon?.(row.original)}
imageSrc={qualifierColumn.getImageSrc?.(row.original)}
imageAlt={qualifierColumn.getImageAlt?.(row.original)}
background={qualifierColumn.background}
selectable={isSelectable}
selected={isSelectable && row.getIsSelected()}
/>

View File

@@ -1,8 +1,10 @@
"use client";
import { cn } from "@opal/utils";
import { Button, Pagination, SelectButton } from "@opal/components";
import Text from "@/refresh-components/texts/Text";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
import { SvgEye, SvgXCircle } from "@opal/icons";
import type { ReactNode } from "react";
@@ -43,6 +45,9 @@ interface FooterSelectionModeProps {
onPageChange: (page: number) => void;
/** Unit label for count pagination. @default "items" */
units?: string;
/** Controls overall footer sizing. `"lg"` (default) or `"md"`. */
size?: TableSize;
className?: string;
}
/**
@@ -68,6 +73,7 @@ interface FooterSummaryModeProps {
leftExtra?: ReactNode;
/** Unit label for the summary text, e.g. "users". */
units?: string;
className?: string;
}
/**
@@ -104,7 +110,11 @@ export default function Footer(props: FooterProps) {
const isSmall = resolvedSize === "md";
return (
<div
className="table-footer flex w-full items-center justify-between border-t border-border-01"
className={cn(
"table-footer",
"flex w-full items-center justify-between border-t border-border-01",
props.className
)}
data-size={resolvedSize}
>
{/* Left side */}

View File

@@ -1,10 +1,10 @@
"use client";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
interface QualifierContainerProps {
type: "head" | "cell";
children?: React.ReactNode;
size?: TableSize;
/** Pass-through click handler (e.g. stopPropagation on body cells). */
onClick?: (e: React.MouseEvent) => void;
}
@@ -12,9 +12,11 @@ interface QualifierContainerProps {
export default function QualifierContainer({
type,
children,
size,
onClick,
}: QualifierContainerProps) {
const resolvedSize = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
const Tag = type === "head" ? "th" : "td";

View File

@@ -7,7 +7,6 @@ row selection, drag-and-drop reordering, and server-side mode.
```tsx
import { Table, createTableColumns } from "@opal/components";
import { SvgUser } from "@opal/icons";
interface User {
id: string;
@@ -19,10 +18,11 @@ interface User {
const tc = createTableColumns<User>();
const columns = [
tc.qualifier({ content: "icon", getContent: () => SvgUser }),
tc.qualifier({ content: "avatar-user", getInitials: (r) => r.name?.[0] ?? "?" }),
tc.column("email", {
header: "Name",
weight: 22,
minWidth: 140,
cell: (email, row) => <span>{row.name ?? email}</span>,
}),
tc.column("status", {
@@ -40,7 +40,7 @@ function UsersTable({ users }: { users: User[] }) {
columns={columns}
getRowId={(r) => r.id}
pageSize={10}
footer={{}}
footer={{ mode: "summary" }}
/>
);
}
@@ -55,7 +55,7 @@ function UsersTable({ users }: { users: User[] }) {
| `getRowId` | `(row: TData) => string` | required | Unique row identifier |
| `pageSize` | `number` | `10` | Rows per page (`Infinity` disables pagination) |
| `size` | `"md" \| "lg"` | `"lg"` | Density variant |
| `footer` | `DataTableFooterConfig` | — | Footer configuration (mode is derived from `selectionBehavior`) |
| `footer` | `DataTableFooterConfig` | — | Footer mode (`"selection"` or `"summary"`) |
| `initialSorting` | `SortingState` | — | Initial sort state |
| `initialColumnVisibility` | `VisibilityState` | — | Initial column visibility |
| `draggable` | `DataTableDraggableConfig` | — | Enable drag-and-drop reordering |
@@ -76,8 +76,7 @@ function UsersTable({ users }: { users: User[] }) {
- `tc.displayColumn(opts)` — non-accessor custom column
- `tc.actions(opts)` — trailing actions column with visibility/sorting popovers
## Footer
## Footer Modes
The footer mode is derived automatically from `selectionBehavior`:
- **Selection footer** (when `selectionBehavior` is `"single-select"` or `"multi-select"`) — shows selection count, optional view/clear buttons, count pagination
- **Summary footer** (when `selectionBehavior` is `"no-select"` or omitted) — shows "Showing X\~Y of Z", list pagination, optional extra element
- **`"selection"`** — shows selection count, optional view/clear buttons, count pagination
- **`"summary"`** — shows "Showing X~Y of Z", list pagination, optional extra element

View File

@@ -1,6 +1,5 @@
import type { Meta, StoryObj } from "@storybook/react";
import { Table, createTableColumns } from "@opal/components";
import { SvgUser } from "@opal/icons";
// ---------------------------------------------------------------------------
// Sample data
@@ -109,14 +108,17 @@ const tc = createTableColumns<User>();
const columns = [
tc.qualifier({
content: "icon",
getContent: () => SvgUser,
background: true,
content: "avatar-user",
getInitials: (r) =>
r.name
.split(" ")
.map((n) => n[0])
.join(""),
}),
tc.column("name", { header: "Name", weight: 25 }),
tc.column("email", { header: "Email", weight: 30 }),
tc.column("role", { header: "Role", weight: 15 }),
tc.column("status", { header: "Status", weight: 15 }),
tc.column("name", { header: "Name", weight: 25, minWidth: 120 }),
tc.column("email", { header: "Email", weight: 30, minWidth: 160 }),
tc.column("role", { header: "Role", weight: 15, minWidth: 80 }),
tc.column("status", { header: "Status", weight: 15, minWidth: 80 }),
tc.actions(),
];
@@ -140,7 +142,7 @@ export const Default: Story = {
columns={columns}
getRowId={(r) => r.id}
pageSize={8}
footer={{}}
footer={{ mode: "summary" }}
/>
),
};

View File

@@ -1,20 +1,24 @@
import { cn } from "@opal/utils";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
import type { WithoutStyles } from "@/types";
interface TableCellProps
extends WithoutStyles<React.TdHTMLAttributes<HTMLTableCellElement>> {
children: React.ReactNode;
size?: TableSize;
/** Explicit pixel width for the cell. */
width?: number;
}
export default function TableCell({
size,
width,
children,
...props
}: TableCellProps) {
const resolvedSize = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
return (
<td
className="tbl-cell overflow-hidden"

View File

@@ -1,8 +1,5 @@
"use client";
import React from "react";
import { cn } from "@opal/utils";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { WithoutStyles } from "@/types";
import type { ExtremaSizeVariants, SizeVariants } from "@opal/types";
@@ -12,15 +9,20 @@ import type { ExtremaSizeVariants, SizeVariants } from "@opal/types";
type TableSize = Extract<SizeVariants, "md" | "lg">;
type TableVariant = "rows" | "cards";
type TableQualifier = "simple" | "avatar" | "icon";
type SelectionBehavior = "no-select" | "single-select" | "multi-select";
interface TableProps
extends WithoutStyles<React.TableHTMLAttributes<HTMLTableElement>> {
ref?: React.Ref<HTMLTableElement>;
/** Size preset for the table. @default "lg" */
size?: TableSize;
/** Visual row variant. @default "cards" */
variant?: TableVariant;
/** Row selection behavior. @default "no-select" */
selectionBehavior?: SelectionBehavior;
/** Leading qualifier column type. @default null */
qualifier?: TableQualifier;
/** Height behavior. `"fit"` = shrink to content, `"full"` = fill available space. */
heightVariant?: ExtremaSizeVariants;
/** Explicit pixel width for the table (e.g. from `table.getTotalSize()`).
@@ -36,13 +38,14 @@ interface TableProps
function Table({
ref,
size = "lg",
variant = "cards",
selectionBehavior = "no-select",
qualifier = "simple",
heightVariant,
width,
...props
}: TableProps) {
const size = useTableSize();
return (
<table
ref={ref}
@@ -51,6 +54,7 @@ function Table({
data-size={size}
data-variant={variant}
data-selection={selectionBehavior}
data-qualifier={qualifier}
data-height={heightVariant}
{...props}
/>
@@ -58,4 +62,10 @@ function Table({
}
export default Table;
export type { TableProps, TableSize, TableVariant, SelectionBehavior };
export type {
TableProps,
TableSize,
TableVariant,
TableQualifier,
SelectionBehavior,
};

View File

@@ -1,6 +1,7 @@
import { cn } from "@opal/utils";
import Text from "@/refresh-components/texts/Text";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
import type { WithoutStyles } from "@/types";
import { Button } from "@opal/components";
import { SvgChevronDown, SvgChevronUp, SvgHandle, SvgSort } from "@opal/icons";
@@ -29,6 +30,8 @@ interface TableHeadCustomProps {
icon?: (sorted: SortDirection) => IconFunctionComponent;
/** Text alignment for the column. Defaults to `"left"`. */
alignment?: "left" | "center" | "right";
/** Cell density. `"md"` uses tighter padding for denser layouts. */
size?: TableSize;
/** Column width in pixels. Applied as an inline style on the `<th>`. */
width?: number;
/** When `true`, shows a bottom border on hover. Defaults to `true`. */
@@ -78,11 +81,13 @@ export default function TableHead({
resizable,
onResizeStart,
alignment = "left",
size,
width,
bottomBorder = true,
...thProps
}: TableHeadProps) {
const resolvedSize = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
const isSmall = resolvedSize === "md";
return (
<th

View File

@@ -3,13 +3,19 @@
import React from "react";
import { cn } from "@opal/utils";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
import { SvgUser } from "@opal/icons";
import type { IconFunctionComponent } from "@opal/types";
import type { QualifierContentType } from "@opal/components/table/types";
import Checkbox from "@/refresh-components/inputs/Checkbox";
import Text from "@/refresh-components/texts/Text";
interface TableQualifierProps {
className?: string;
/** Content type displayed in the qualifier */
content: QualifierContentType;
/** Size variant */
size?: TableSize;
/** Disables interaction */
disabled?: boolean;
/** Whether to show a selection checkbox overlay */
@@ -18,33 +24,54 @@ interface TableQualifierProps {
selected?: boolean;
/** Called when the checkbox is toggled */
onSelectChange?: (selected: boolean) => void;
/** Icon component to render (for "icon" content). */
/** Icon component to render (for "icon" content type) */
icon?: IconFunctionComponent;
/** Image source URL (for "image" content). */
/** Image source URL (for "image" content type) */
imageSrc?: string;
/** Image alt text (for "image" content). */
/** Image alt text */
imageAlt?: string;
/** Show a tinted background container behind the content. */
background?: boolean;
/** User initials (for "avatar-user" content type) */
initials?: string;
}
const iconSizes = {
lg: 28,
md: 24,
lg: 16,
md: 14,
} as const;
function getOverlayStyles(selected: boolean, disabled: boolean) {
function getQualifierStyles(selected: boolean, disabled: boolean) {
if (disabled) {
return selected ? "flex bg-action-link-00" : "hidden";
return {
container: "bg-background-neutral-03",
icon: "stroke-text-02",
overlay: selected ? "flex bg-action-link-00" : "hidden",
overlayImage: selected ? "flex bg-mask-01 backdrop-blur-02" : "hidden",
};
}
if (selected) {
return "flex bg-action-link-00";
return {
container: "bg-action-link-00",
icon: "stroke-text-03",
overlay: "flex bg-action-link-00",
overlayImage: "flex bg-mask-01 backdrop-blur-02",
};
}
return "flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-background-tint-01";
return {
container: "bg-background-tint-01",
icon: "stroke-text-03",
overlay:
"flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-background-tint-01",
overlayImage:
"flex opacity-0 group-hover/row:opacity-100 group-focus-within/row:opacity-100 bg-mask-01 group-hover/row:backdrop-blur-02 group-focus-within/row:backdrop-blur-02",
};
}
function TableQualifier({
className,
content,
size,
disabled = false,
selectable = false,
selected = false,
@@ -52,67 +79,100 @@ function TableQualifier({
icon: Icon,
imageSrc,
imageAlt = "",
background = false,
initials,
}: TableQualifierProps) {
const resolvedSize = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
const isRound = content === "avatar-icon" || content === "avatar-user";
const iconSize = iconSizes[resolvedSize];
const overlayStyles = getOverlayStyles(selected, disabled);
const styles = getQualifierStyles(selected, disabled);
function renderContent() {
switch (content) {
case "icon":
return Icon ? <Icon size={iconSize} /> : null;
return Icon ? <Icon size={iconSize} className={styles.icon} /> : null;
case "simple":
return null;
case "image":
return imageSrc ? (
<img
src={imageSrc}
alt={imageAlt}
className="h-full w-full rounded-08 object-cover"
className={cn(
"h-full w-full object-cover",
isRound ? "rounded-full" : "rounded-08"
)}
/>
) : null;
case "simple":
case "avatar-icon":
return <SvgUser size={iconSize} className={styles.icon} />;
case "avatar-user":
return (
<div
className={cn(
"flex items-center justify-center rounded-full bg-background-neutral-inverted-00",
resolvedSize === "lg" ? "h-7 w-7" : "h-6 w-6"
)}
>
<Text
inverted
secondaryAction
text05
className="select-none uppercase"
>
{initials}
</Text>
</div>
);
default:
return null;
}
}
const inner = renderContent();
const showBackground = background && content !== "simple";
return (
<div
className={cn(
"group relative inline-flex shrink-0 items-center justify-center",
resolvedSize === "lg" ? "h-9 w-9" : "h-7 w-7",
disabled ? "cursor-not-allowed" : "cursor-default"
disabled ? "cursor-not-allowed" : "cursor-default",
className
)}
>
{showBackground ? (
{/* Inner qualifier container — no background for "simple" */}
{content !== "simple" && (
<div
className={cn(
"flex items-center justify-center overflow-hidden rounded-08 transition-colors",
"flex items-center justify-center overflow-hidden transition-colors",
resolvedSize === "lg" ? "h-9 w-9" : "h-7 w-7",
disabled
? "bg-background-neutral-03"
: selected
? "bg-action-link-00"
: "bg-background-tint-01"
isRound ? "rounded-full" : "rounded-08",
styles.container,
content === "image" && disabled && !selected && "opacity-50"
)}
>
{inner}
{renderContent()}
</div>
) : (
inner
)}
{/* Selection overlay */}
{selectable && (
<div
className={cn(
"absolute inset-0 items-center justify-center rounded-08",
content === "simple" ? "flex" : overlayStyles
"absolute inset-0 items-center justify-center",
content === "simple"
? "flex"
: isRound
? "rounded-full"
: "rounded-08",
content === "simple"
? "flex"
: content === "image"
? styles.overlayImage
: styles.overlay
)}
>
<Checkbox

View File

@@ -2,6 +2,7 @@
import { cn } from "@opal/utils";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
import type { WithoutStyles } from "@/types";
import { useSortable } from "@dnd-kit/sortable";
import { CSS } from "@dnd-kit/utilities";
@@ -11,7 +12,7 @@ import { SvgHandle } from "@opal/icons";
// Types
// ---------------------------------------------------------------------------
export interface TableRowProps
interface TableRowProps
extends WithoutStyles<React.HTMLAttributes<HTMLTableRowElement>> {
ref?: React.Ref<HTMLTableRowElement>;
selected?: boolean;
@@ -21,6 +22,8 @@ export interface TableRowProps
sortableId?: string;
/** Show drag handle overlay. Defaults to true when sortableId is set. */
showDragHandle?: boolean;
/** Size variant for the drag handle */
size?: TableSize;
}
// ---------------------------------------------------------------------------
@@ -30,13 +33,15 @@ export interface TableRowProps
function SortableTableRow({
sortableId,
showDragHandle = true,
size,
selected,
disabled,
ref: _externalRef,
children,
...props
}: TableRowProps) {
const resolvedSize = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
const {
attributes,
@@ -100,9 +105,10 @@ function SortableTableRow({
// Main component
// ---------------------------------------------------------------------------
export default function TableRow({
function TableRow({
sortableId,
showDragHandle,
size,
selected,
disabled,
ref,
@@ -113,6 +119,7 @@ export default function TableRow({
<SortableTableRow
sortableId={sortableId}
showDragHandle={showDragHandle}
size={size}
selected={selected}
disabled={disabled}
ref={ref}
@@ -131,3 +138,6 @@ export default function TableRow({
/>
);
}
export default TableRow;
export type { TableRowProps };

View File

@@ -25,14 +25,18 @@ import type { SortDirection } from "@opal/components/table/TableHead";
interface QualifierConfig<TData> {
/** Content type for body-row `<TableQualifier>`. @default "simple" */
content?: QualifierContentType;
/** Return the icon component to render for a row (for "icon" content). */
getContent?: (row: TData) => IconFunctionComponent;
/** Return the image URL to render for a row (for "image" content). */
/** Content type for the header `<TableQualifier>`. @default "simple" */
headerContentType?: QualifierContentType;
/** Extract initials from a row (for "avatar-user" content). */
getInitials?: (row: TData) => string;
/** Extract icon from a row (for "icon" / "avatar-icon" content). */
getIcon?: (row: TData) => IconFunctionComponent;
/** Extract image src from a row (for "image" content). */
getImageSrc?: (row: TData) => string;
/** Return the image alt text for a row (for "image" content). @default "" */
getImageAlt?: (row: TData) => string;
/** Show a tinted background container behind the content. @default false */
background?: boolean;
/** Whether to show selection checkboxes on the qualifier. @default true */
selectable?: boolean;
/** Whether to render qualifier content in the header. @default true */
header?: boolean;
}
// ---------------------------------------------------------------------------
@@ -54,6 +58,8 @@ interface DataColumnConfig<TData, TValue> {
icon?: (sorted: SortDirection) => IconFunctionComponent;
/** Column weight for proportional distribution. @default 20 */
weight?: number;
/** Minimum column width in pixels. @default 50 */
minWidth?: number;
}
// ---------------------------------------------------------------------------
@@ -126,9 +132,9 @@ interface TableColumnsBuilder<TData> {
* ```ts
* const tc = createTableColumns<TeamMember>();
* const columns = [
* tc.qualifier({ content: "icon", getContent: (r) => UserIcon }),
* tc.column("name", { header: "Name", weight: 23 }),
* tc.column("email", { header: "Email", weight: 28 }),
* tc.qualifier({ content: "avatar-user", getInitials: (r) => r.initials }),
* tc.column("name", { header: "Name", weight: 23, minWidth: 120 }),
* tc.column("email", { header: "Email", weight: 28, minWidth: 150 }),
* tc.actions(),
* ];
* ```
@@ -156,10 +162,12 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
width: (size: TableSize) =>
size === "md" ? { fixed: 36 } : { fixed: 44 },
content,
getContent: config?.getContent,
headerContentType: config?.headerContentType,
getInitials: config?.getInitials,
getIcon: config?.getIcon,
getImageSrc: config?.getImageSrc,
getImageAlt: config?.getImageAlt,
background: config?.background,
selectable: config?.selectable,
header: config?.header,
};
},
@@ -175,6 +183,7 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
enableHiding = true,
icon,
weight = 20,
minWidth = 50,
} = config;
const def = helper.accessor(accessor as any, {
@@ -192,7 +201,7 @@ export function createTableColumns<TData>(): TableColumnsBuilder<TData> {
kind: "data",
id: accessor as string,
def,
width: { weight, minWidth: Math.max(header.length * 8 + 40, 80) },
width: { weight, minWidth },
icon,
};
},

View File

@@ -39,12 +39,15 @@ import type {
import type { TableSize } from "@opal/components/table/TableSizeContext";
// ---------------------------------------------------------------------------
// SelectionBehavior
// Qualifier × SelectionBehavior
// ---------------------------------------------------------------------------
type Qualifier = "simple" | "avatar" | "icon";
type SelectionBehavior = "no-select" | "single-select" | "multi-select";
export type DataTableProps<TData> = BaseDataTableProps<TData> & {
/** Leading qualifier column type. @default "simple" */
qualifier?: Qualifier;
/** Row selection behavior. @default "no-select" */
selectionBehavior?: SelectionBehavior;
};
@@ -128,8 +131,8 @@ function processColumns<TData>(
* ```tsx
* const tc = createTableColumns<TeamMember>();
* const columns = [
* tc.qualifier({ content: "icon", getContent: (r) => UserIcon }),
* tc.column("name", { header: "Name", weight: 23 }),
* tc.qualifier({ content: "avatar-user", getInitials: (r) => r.initials }),
* tc.column("name", { header: "Name", weight: 23, minWidth: 120 }),
* tc.column("email", { header: "Email", weight: 28 }),
* tc.actions(),
* ];
@@ -149,6 +152,7 @@ export function Table<TData>(props: DataTableProps<TData>) {
footer,
size = "lg",
variant = "cards",
qualifier = "simple",
selectionBehavior = "no-select",
onSelectionChange,
onRowClick,
@@ -162,15 +166,11 @@ export function Table<TData>(props: DataTableProps<TData>) {
const effectivePageSize = pageSize ?? (footer ? 10 : data.length);
// Whether the qualifier column should exist in the DOM.
// Derived from the column definitions: if a qualifier column exists with
// content !== "simple", always show it. If content === "simple" (or no
// qualifier column defined), show only for multi-select (checkboxes).
const qualifierColDef = columns.find(
(c): c is OnyxQualifierColumn<TData> => c.kind === "qualifier"
);
// "simple" only gets a qualifier column for multi-select (checkboxes).
// "simple" + no-select/single-select = no qualifier column — single-select
// uses row-level background coloring instead.
const hasQualifierColumn =
(qualifierColDef != null && qualifierColDef.content !== "simple") ||
selectionBehavior === "multi-select";
qualifier !== "simple" || selectionBehavior === "multi-select";
// 1. Process columns (memoized on columns + size)
const { tanstackColumns, widthConfig, qualifierColumn, columnKindMap } =
@@ -357,6 +357,7 @@ export function Table<TData>(props: DataTableProps<TData>) {
}}
>
<TableElement
size={size}
variant={variant}
selectionBehavior={selectionBehavior}
width={
@@ -418,12 +419,14 @@ export function Table<TData>(props: DataTableProps<TData>) {
columnVisibility={
table.getState().columnVisibility
}
size={size}
/>
)}
{actionsDef.showSorting !== false && (
<SortingPopover
table={table}
sorting={table.getState().sorting}
size={size}
footerText={actionsDef.sortingFooterText}
/>
)}
@@ -538,6 +541,12 @@ export function Table<TData>(props: DataTableProps<TData>) {
if (cellColDef?.kind === "qualifier") {
const qDef = cellColDef as OnyxQualifierColumn<TData>;
// Resolve content based on the qualifier prop:
// - "simple" renders nothing (checkbox only when selectable)
// - "avatar"/"icon" render from column config
const qualifierContent =
qualifier === "simple" ? "simple" : qDef.content;
return (
<QualifierContainer
key={cell.id}
@@ -545,11 +554,10 @@ export function Table<TData>(props: DataTableProps<TData>) {
onClick={(e) => e.stopPropagation()}
>
<TableQualifier
content={qDef.content}
icon={qDef.getContent?.(row.original)}
content={qualifierContent}
initials={qDef.getInitials?.(row.original)}
icon={qDef.getIcon?.(row.original)}
imageSrc={qDef.getImageSrc?.(row.original)}
imageAlt={qDef.getImageAlt?.(row.original)}
background={qDef.background}
selectable={showQualifierCheckbox}
selected={
showQualifierCheckbox && row.getIsSelected()

View File

@@ -277,7 +277,7 @@ function createSplitterResizeHandler(
* const { containerRef, columnWidths, createResizeHandler } = useColumnWidths({
* headers: table.getHeaderGroups()[0].headers,
* fixedColumnIds: new Set(["actions"]),
* columnMinWidths: { name: 72, status: 80 }, // auto-derived by columns.ts from header text
* columnMinWidths: { name: 120, status: 80 },
* });
* ```
*/

View File

@@ -25,7 +25,8 @@
/* ---- TableHead ---- */
.table-head {
@apply relative;
@apply relative sticky top-0 z-20;
background: var(--table-header-bg, transparent);
}
.table-head[data-size="lg"] {
@apply px-2 py-1;
@@ -129,7 +130,8 @@ table[data-variant="cards"] .tbl-row:has(:focus-visible) > td {
/* ---- QualifierContainer ---- */
.tbl-qualifier[data-type="head"] {
@apply w-px whitespace-nowrap py-1;
@apply w-px whitespace-nowrap py-1 sticky top-0 z-20;
background: var(--table-header-bg, transparent);
}
.tbl-qualifier[data-type="head"][data-size="md"] {
@apply py-0.5;
@@ -145,10 +147,11 @@ table[data-variant="cards"] .tbl-row:has(:focus-visible) > td {
/* ---- ActionsContainer ---- */
.tbl-actions {
@apply w-px whitespace-nowrap px-1;
@apply sticky right-0 w-px whitespace-nowrap px-1;
}
.tbl-actions[data-type="head"] {
@apply px-2 py-1;
@apply z-30 sticky top-0 px-2 py-1;
background: var(--table-header-bg, transparent);
}
/* ---- Footer ---- */

View File

@@ -30,7 +30,12 @@ export type ColumnWidth = DataColumnWidth | FixedColumnWidth;
// Column kind discriminant
// ---------------------------------------------------------------------------
export type QualifierContentType = "simple" | "icon" | "image";
export type QualifierContentType =
| "icon"
| "simple"
| "image"
| "avatar-icon"
| "avatar-user";
export type OnyxColumnKind = "qualifier" | "data" | "display" | "actions";
@@ -51,14 +56,18 @@ export interface OnyxQualifierColumn<TData> extends OnyxColumnBase<TData> {
kind: "qualifier";
/** Content type for body-row `<TableQualifier>`. */
content: QualifierContentType;
/** Return the icon component to render for a row (for "icon" content). */
getContent?: (row: TData) => IconFunctionComponent;
/** Return the image URL to render for a row (for "image" content). */
/** Content type for the header `<TableQualifier>`. @default "simple" */
headerContentType?: QualifierContentType;
/** Extract initials from a row (for "avatar-user" content). */
getInitials?: (row: TData) => string;
/** Extract icon from a row (for "icon" / "avatar-icon" content). */
getIcon?: (row: TData) => IconFunctionComponent;
/** Extract image src from a row (for "image" content). */
getImageSrc?: (row: TData) => string;
/** Return the image alt text for a row (for "image" content). @default "" */
getImageAlt?: (row: TData) => string;
/** Show a tinted background container behind the content. @default false */
background?: boolean;
/** Whether to show selection checkboxes on the qualifier. @default true */
selectable?: boolean;
/** Whether to render qualifier content in the header. @default true */
header?: boolean;
}
/** Data column — accessor-based column with sorting/resizing. */

View File

@@ -0,0 +1,320 @@
"use client";
import Text from "@/refresh-components/texts/Text";
import { Persona } from "./interfaces";
import { useRouter } from "next/navigation";
import Checkbox from "@/refresh-components/inputs/Checkbox";
import { toast } from "@/hooks/useToast";
import { useState, useMemo, useEffect } from "react";
import { UniqueIdentifier } from "@dnd-kit/core";
import { DraggableTable } from "@/components/table/DraggableTable";
import {
deletePersona,
personaComparator,
togglePersonaFeatured,
togglePersonaVisibility,
} from "./lib";
import { FiEdit2 } from "react-icons/fi";
import { useUser } from "@/providers/UserProvider";
import { Button } from "@opal/components";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import { SvgAlertCircle, SvgTrash } from "@opal/icons";
import type { Route } from "next";
function PersonaTypeDisplay({ persona }: { persona: Persona }) {
if (persona.builtin_persona) {
return <Text as="p">Built-In</Text>;
}
if (persona.featured) {
return <Text as="p">Featured</Text>;
}
if (persona.is_public) {
return <Text as="p">Public</Text>;
}
if (persona.groups.length > 0 || persona.users.length > 0) {
return <Text as="p">Shared</Text>;
}
return (
<Text as="p">Personal {persona.owner && <>({persona.owner.email})</>}</Text>
);
}
export function PersonasTable({
personas,
refreshPersonas,
currentPage,
pageSize,
}: {
personas: Persona[];
refreshPersonas: () => void;
currentPage: number;
pageSize: number;
}) {
const router = useRouter();
const { refreshUser, isAdmin } = useUser();
const editablePersonas = useMemo(() => {
return personas.filter((p) => !p.builtin_persona);
}, [personas]);
const editablePersonaIds = useMemo(() => {
return new Set(editablePersonas.map((p) => p.id.toString()));
}, [editablePersonas]);
const [finalPersonas, setFinalPersonas] = useState<Persona[]>([]);
const [deleteModalOpen, setDeleteModalOpen] = useState(false);
const [personaToDelete, setPersonaToDelete] = useState<Persona | null>(null);
const [defaultModalOpen, setDefaultModalOpen] = useState(false);
const [personaToToggleDefault, setPersonaToToggleDefault] =
useState<Persona | null>(null);
useEffect(() => {
const editable = editablePersonas.sort(personaComparator);
const nonEditable = personas
.filter((p) => !editablePersonaIds.has(p.id.toString()))
.sort(personaComparator);
setFinalPersonas([...editable, ...nonEditable]);
}, [editablePersonas, personas, editablePersonaIds]);
const updatePersonaOrder = async (orderedPersonaIds: UniqueIdentifier[]) => {
const reorderedPersonas = orderedPersonaIds.map(
(id) => personas.find((persona) => persona.id.toString() === id)!
);
setFinalPersonas(reorderedPersonas);
// Calculate display_priority based on current page.
// Page 1 (items 0-9): priorities 0-9
// Page 2 (items 10-19): priorities 10-19, etc.
const pageStartIndex = (currentPage - 1) * pageSize;
const displayPriorityMap = new Map<UniqueIdentifier, number>();
orderedPersonaIds.forEach((personaId, ind) => {
displayPriorityMap.set(personaId, pageStartIndex + ind);
});
const response = await fetch("/api/admin/agents/display-priorities", {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
display_priority_map: Object.fromEntries(displayPriorityMap),
}),
});
if (!response.ok) {
toast.error(`Failed to update persona order - ${await response.text()}`);
setFinalPersonas(personas);
await refreshPersonas();
return;
}
await refreshPersonas();
await refreshUser();
};
const openDeleteModal = (persona: Persona) => {
setPersonaToDelete(persona);
setDeleteModalOpen(true);
};
const closeDeleteModal = () => {
setDeleteModalOpen(false);
setPersonaToDelete(null);
};
const handleDeletePersona = async () => {
if (personaToDelete) {
const response = await deletePersona(personaToDelete.id);
if (response.ok) {
refreshPersonas();
closeDeleteModal();
} else {
toast.error(`Failed to delete persona - ${await response.text()}`);
}
}
};
const openDefaultModal = (persona: Persona) => {
setPersonaToToggleDefault(persona);
setDefaultModalOpen(true);
};
const closeDefaultModal = () => {
setDefaultModalOpen(false);
setPersonaToToggleDefault(null);
};
const handleToggleDefault = async () => {
if (personaToToggleDefault) {
const response = await togglePersonaFeatured(
personaToToggleDefault.id,
personaToToggleDefault.featured
);
if (response.ok) {
refreshPersonas();
closeDefaultModal();
} else {
toast.error(`Failed to update persona - ${await response.text()}`);
}
}
};
return (
<div>
{deleteModalOpen && personaToDelete && (
<ConfirmationModalLayout
icon={SvgAlertCircle}
title="Delete Agent"
onClose={closeDeleteModal}
submit={<Button onClick={handleDeletePersona}>Delete</Button>}
>
{`Are you sure you want to delete ${personaToDelete.name}?`}
</ConfirmationModalLayout>
)}
{defaultModalOpen &&
personaToToggleDefault &&
(() => {
const isDefault = personaToToggleDefault.featured;
const title = isDefault
? "Remove Featured Agent"
: "Set Featured Agent";
const buttonText = isDefault ? "Remove Feature" : "Set as Featured";
const text = isDefault
? `Are you sure you want to remove the featured status of ${personaToToggleDefault.name}?`
: `Are you sure you want to set the featured status of ${personaToToggleDefault.name}?`;
const additionalText = isDefault
? `Removing "${personaToToggleDefault.name}" as a featured agent will not affect its visibility or accessibility.`
: `Setting "${personaToToggleDefault.name}" as a featured agent will make it public and visible to all users. This action cannot be undone.`;
return (
<ConfirmationModalLayout
icon={SvgAlertCircle}
title={title}
onClose={closeDefaultModal}
submit={
<Button onClick={handleToggleDefault}>{buttonText}</Button>
}
>
<div className="flex flex-col gap-2">
<Text as="p">{text}</Text>
<Text as="p" text03>
{additionalText}
</Text>
</div>
</ConfirmationModalLayout>
);
})()}
<DraggableTable
headers={[
"Name",
"Description",
"Type",
"Featured Agent",
"Is Visible",
"Delete",
]}
isAdmin={isAdmin}
rows={finalPersonas.map((persona) => {
const isEditable = editablePersonas.includes(persona);
return {
id: persona.id.toString(),
cells: [
<div key="name" className="flex">
{!persona.builtin_persona && (
<FiEdit2
className="mr-1 my-auto cursor-pointer"
onClick={() =>
router.push(
`/app/agents/edit/${
persona.id
}?u=${Date.now()}&admin=true` as Route
)
}
/>
)}
<p className="text font-medium whitespace-normal break-none">
{persona.name}
</p>
</div>,
<p
key="description"
className="whitespace-normal break-all max-w-2xl"
>
{persona.description}
</p>,
<PersonaTypeDisplay key={persona.id} persona={persona} />,
<div
key="featured"
onClick={() => {
openDefaultModal(persona);
}}
className={`
px-1 py-0.5 rounded flex hover:bg-accent-background-hovered cursor-pointer select-none w-fit items-center gap-2
`}
>
<div className="my-auto flex-none w-22">
{!persona.featured ? (
<div className="text-error">Not Featured</div>
) : (
"Featured"
)}
</div>
<Checkbox checked={persona.featured} />
</div>,
<div
key="is_visible"
onClick={async () => {
const response = await togglePersonaVisibility(
persona.id,
persona.is_visible
);
if (response.ok) {
refreshPersonas();
} else {
toast.error(
`Failed to update persona - ${await response.text()}`
);
}
}}
className={`
px-1 py-0.5 rounded flex hover:bg-accent-background-hovered cursor-pointer select-none w-fit items-center gap-2
`}
>
<div className="my-auto w-fit">
{!persona.is_visible ? (
<div className="text-error">Hidden</div>
) : (
"Visible"
)}
</div>
<Checkbox checked={persona.is_visible} />
</div>,
<div key="edit" className="flex">
<div className="mr-auto my-auto">
{!persona.builtin_persona && isEditable ? (
<Button
icon={SvgTrash}
prominence="tertiary"
onClick={() => openDeleteModal(persona)}
/>
) : (
<Text as="p">-</Text>
)}
</div>
</div>,
],
staticModifiers: [[1, "lg:w-[250px] xl:w-[400px] 2xl:w-[550px]"]],
};
})}
setRows={updatePersonaOrder}
/>
</div>
);
}

View File

@@ -1 +1,160 @@
export { default } from "@/refresh-pages/admin/AgentsPage";
"use client";
import { PersonasTable } from "./PersonaTable";
import Text from "@/components/ui/text";
import Title from "@/components/ui/title";
import Separator from "@/refresh-components/Separator";
import { SubLabel } from "@/components/Field";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { useAdminPersonas } from "@/hooks/useAdminPersonas";
import { Persona } from "./interfaces";
import { ThreeDotsLoader } from "@/components/Loading";
import { ErrorCallout } from "@/components/ErrorCallout";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import { useState, useEffect } from "react";
import { Pagination } from "@opal/components";
const route = ADMIN_ROUTES.AGENTS;
const PAGE_SIZE = 20;
function MainContent({
personas,
totalItems,
currentPage,
onPageChange,
refreshPersonas,
}: {
personas: Persona[];
totalItems: number;
currentPage: number;
onPageChange: (page: number) => void;
refreshPersonas: () => void;
}) {
// Filter out default/unified assistants.
// NOTE: The backend should already exclude them if includeDefault = false is
// provided. That change was made with the introduction of pagination; we keep
// this filter here for now for backwards compatibility.
const customPersonas = personas.filter((persona) => !persona.builtin_persona);
const totalPages = Math.ceil(totalItems / PAGE_SIZE);
// Clamp currentPage when totalItems shrinks (e.g., deleting the last item on a page)
useEffect(() => {
if (currentPage > totalPages && totalPages > 0) {
onPageChange(totalPages);
}
}, [currentPage, totalPages, onPageChange]);
return (
<div>
<Text className="mb-2">
Agents are a way to build custom search/question-answering experiences
for different use cases.
</Text>
<Text className="mt-2">They allow you to customize:</Text>
<div className="text-sm">
<ul className="list-disc mt-2 ml-4">
<li>
The prompt used by your LLM of choice to respond to the user query
</li>
<li>The documents that are used as context</li>
</ul>
</div>
<div>
<Separator />
<Title>Create an Agent</Title>
<CreateButton href="/app/agents/create?admin=true">
New Agent
</CreateButton>
<Separator />
<Title>Existing Agents</Title>
{totalItems > 0 ? (
<>
<SubLabel>
Agents will be displayed as options on the Chat / Search
interfaces in the order they are displayed below. Agents marked as
hidden will not be displayed. Editable agents are shown at the
top.
</SubLabel>
<PersonasTable
personas={customPersonas}
refreshPersonas={refreshPersonas}
currentPage={currentPage}
pageSize={PAGE_SIZE}
/>
{totalPages > 1 && (
<Pagination
currentPage={currentPage}
totalPages={totalPages}
onChange={onPageChange}
/>
)}
</>
) : (
<div className="mt-6 p-8 border border-border rounded-lg bg-background-weak text-center">
<Text className="text-lg font-medium mb-2">
No custom agents yet
</Text>
<Text className="text-subtle mb-3">
Create your first agent to:
</Text>
<ul className="text-subtle text-sm list-disc text-left inline-block mb-3">
<li>Build department-specific knowledge bases</li>
<li>Create specialized research agents</li>
<li>Set up compliance and policy advisors</li>
</ul>
<Text className="text-subtle text-sm mb-4">
...and so much more!
</Text>
<CreateButton href="/app/agents/create?admin=true">
Create Your First Agent
</CreateButton>
</div>
)}
</div>
</div>
);
}
export default function Page() {
const [currentPage, setCurrentPage] = useState(1);
const { personas, totalItems, isLoading, error, refresh } = useAdminPersonas({
pageNum: currentPage - 1, // Backend uses 0-indexed pages
pageSize: PAGE_SIZE,
});
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
<SettingsLayouts.Body>
{isLoading && <ThreeDotsLoader />}
{error && (
<ErrorCallout
errorTitle="Failed to load agents"
errorMsg={
error?.info?.message ||
error?.info?.detail ||
"An unknown error occurred"
}
/>
)}
{!isLoading && !error && (
<MainContent
personas={personas}
totalItems={totalItems}
currentPage={currentPage}
onPageChange={setCurrentPage}
refreshPersonas={refresh}
/>
)}
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -125,7 +125,7 @@ export const MAX_FILES_TO_SHOW = 3;
export const MOBILE_SIDEBAR_BREAKPOINT_PX = 640;
export const DESKTOP_SMALL_BREAKPOINT_PX = 912;
export const DESKTOP_MEDIUM_BREAKPOINT_PX = 1232;
export const DEFAULT_AVATAR_SIZE_PX = 18;
export const DEFAULT_AGENT_AVATAR_SIZE_PX = 18;
export const HORIZON_DISTANCE_PX = 800;
export const LOGO_FOLDED_SIZE_PX = 24;
export const LOGO_UNFOLDED_SIZE_PX = 88;

View File

@@ -1,55 +0,0 @@
import { getUserInitials } from "@/lib/user";
describe("getUserInitials", () => {
it("returns first letters of first two name parts", () => {
expect(getUserInitials("Alice Smith", "alice@example.com")).toBe("AS");
});
it("returns first two chars of a single-word name", () => {
expect(getUserInitials("Alice", "alice@example.com")).toBe("AL");
});
it("handles three-word names (uses first two)", () => {
expect(getUserInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
});
it("falls back to email local part with dot separator", () => {
expect(getUserInitials(null, "alice.smith@example.com")).toBe("AS");
});
it("falls back to email local part with underscore separator", () => {
expect(getUserInitials(null, "alice_smith@example.com")).toBe("AS");
});
it("falls back to email local part with hyphen separator", () => {
expect(getUserInitials(null, "alice-smith@example.com")).toBe("AS");
});
it("uses first two chars of email local if no separator", () => {
expect(getUserInitials(null, "alice@example.com")).toBe("AL");
});
it("returns null for empty email local part", () => {
expect(getUserInitials(null, "@example.com")).toBeNull();
});
it("uppercases the result", () => {
expect(getUserInitials("john doe", "jd@test.com")).toBe("JD");
});
it("trims whitespace from name", () => {
expect(getUserInitials(" Alice Smith ", "a@test.com")).toBe("AS");
});
it("returns null for numeric name parts", () => {
expect(getUserInitials("Alice 1st", "x@test.com")).toBeNull();
});
it("returns null for numeric email", () => {
expect(getUserInitials(null, "42@domain.com")).toBeNull();
});
it("falls back to email when name has non-alpha chars", () => {
expect(getUserInitials("A1", "alice@example.com")).toBe("AL");
});
});

View File

@@ -128,54 +128,3 @@ export function getUserDisplayName(user: User | null): string {
// If nothing works, then fall back to anonymous user name
return "Anonymous";
}
/**
* Derive display initials from a user's name or email.
*
* - If a name is provided, uses the first letter of the first two words.
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
* - Returns `null` when no valid alpha initials can be derived.
*/
export function getUserInitials(
name: string | null,
email: string
): string | null {
if (name) {
const words = name.trim().split(/\s+/);
if (words.length >= 2) {
const first = words[0]?.[0];
const second = words[1]?.[0];
if (first && second) {
const result = (first + second).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
return null;
}
if (name.trim().length >= 1) {
const result = name.trim().slice(0, 2).toUpperCase();
if (/^[A-Z]{1,2}$/.test(result)) return result;
}
}
const local = email.split("@")[0];
if (!local || local.length === 0) return null;
const parts = local.split(/[._-]/);
if (parts.length >= 2) {
const first = parts[0]?.[0];
const second = parts[1]?.[0];
if (first && second) {
const result = (first + second).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
return null;
}
if (local.length >= 2) {
const result = local.slice(0, 2).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
if (local.length === 1) {
const result = local.toUpperCase();
if (/^[A-Z]$/.test(result)) return result;
}
return null;
}

View File

@@ -4,7 +4,10 @@ import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
import { buildImgUrl } from "@/app/app/components/files/images/utils";
import { OnyxIcon } from "@/components/icons/icons";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { DEFAULT_AVATAR_SIZE_PX, DEFAULT_AGENT_ID } from "@/lib/constants";
import {
DEFAULT_AGENT_AVATAR_SIZE_PX,
DEFAULT_AGENT_ID,
} from "@/lib/constants";
import CustomAgentAvatar from "@/refresh-components/avatars/CustomAgentAvatar";
import Image from "next/image";
@@ -15,7 +18,7 @@ export interface AgentAvatarProps {
export default function AgentAvatar({
agent,
size = DEFAULT_AVATAR_SIZE_PX,
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
...props
}: AgentAvatarProps) {
const settings = useSettingsContext();

View File

@@ -4,7 +4,7 @@ import { cn } from "@/lib/utils";
import type { IconProps } from "@opal/types";
import Text from "@/refresh-components/texts/Text";
import Image from "next/image";
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
import { DEFAULT_AGENT_AVATAR_SIZE_PX } from "@/lib/constants";
import {
SvgActivitySmall,
SvgAudioEqSmall,
@@ -96,7 +96,7 @@ export default function CustomAgentAvatar({
src,
iconName,
size = DEFAULT_AVATAR_SIZE_PX,
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
}: CustomAgentAvatarProps) {
if (src) {
return (

View File

@@ -1,52 +0,0 @@
import { SvgUser } from "@opal/icons";
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
import { getUserInitials } from "@/lib/user";
import Text from "@/refresh-components/texts/Text";
import type { User } from "@/lib/types";
export interface UserAvatarProps {
user: User;
size?: number;
}
export default function UserAvatar({
user,
size = DEFAULT_AVATAR_SIZE_PX,
}: UserAvatarProps) {
const initials = getUserInitials(
user.personalization?.name ?? null,
user.email
);
if (!initials) {
return (
<div
role="img"
aria-label={`${user.email} avatar`}
className="flex items-center justify-center rounded-full bg-background-tint-01"
style={{ width: size, height: size }}
>
<SvgUser size={size * 0.55} className="stroke-text-03" aria-hidden />
</div>
);
}
return (
<div
role="img"
aria-label={`${user.email} avatar`}
className="flex items-center justify-center rounded-full bg-background-neutral-inverted-00"
style={{ width: size, height: size }}
>
<Text
inverted
secondaryAction
text05
className="select-none"
style={{ fontSize: size * 0.4 }}
>
{initials}
</Text>
</div>
);
}

View File

@@ -1,31 +0,0 @@
"use client";
import { SvgOnyxOctagon, SvgPlus } from "@opal/icons";
import { Button } from "@opal/components";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import Link from "next/link";
import AgentsTable from "./AgentsPage/AgentsTable";
// ---------------------------------------------------------------------------
// Page
// ---------------------------------------------------------------------------
export default function AgentsPage() {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
title="Agents"
icon={SvgOnyxOctagon}
rightChildren={
<Link href="/app/agents/create?admin=true">
<Button icon={SvgPlus}>New Agent</Button>
</Link>
}
/>
<SettingsLayouts.Body>
<AgentsTable />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,215 +0,0 @@
"use client";
import { useState } from "react";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import {
SvgMoreHorizontal,
SvgEdit,
SvgEye,
SvgEyeClosed,
SvgStar,
SvgTrash,
SvgAlertCircle,
} from "@opal/icons";
import Popover from "@/refresh-components/Popover";
import Separator from "@/refresh-components/Separator";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import { Section } from "@/layouts/general-layouts";
import Text from "@/refresh-components/texts/Text";
import { toast } from "@/hooks/useToast";
import { useRouter } from "next/navigation";
import {
deleteAgent,
toggleAgentFeatured,
toggleAgentVisibility,
} from "@/refresh-pages/admin/AgentsPage/svc";
import type { AgentRow } from "@/refresh-pages/admin/AgentsPage/interfaces";
import type { Route } from "next";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
enum Modal {
DELETE = "delete",
TOGGLE_FEATURED = "toggleFeatured",
}
interface AgentRowActionsProps {
agent: AgentRow;
onMutate: () => void;
}
// ---------------------------------------------------------------------------
// Component
// ---------------------------------------------------------------------------
export default function AgentRowActions({
agent,
onMutate,
}: AgentRowActionsProps) {
const router = useRouter();
const [modal, setModal] = useState<Modal | null>(null);
const [popoverOpen, setPopoverOpen] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
async function handleAction(
action: () => Promise<void>,
successMessage: string
) {
setIsSubmitting(true);
try {
await action();
onMutate();
toast.success(successMessage);
setModal(null);
} catch (err) {
toast.error(err instanceof Error ? err.message : "An error occurred");
} finally {
setIsSubmitting(false);
}
}
const openModal = (type: Modal) => {
setPopoverOpen(false);
setModal(type);
};
return (
<>
<Popover open={popoverOpen} onOpenChange={setPopoverOpen}>
<Popover.Trigger asChild>
<Button prominence="tertiary" icon={SvgMoreHorizontal} />
</Popover.Trigger>
<Popover.Content align="end" width="sm">
<Section
gap={0.5}
height="auto"
alignItems="stretch"
justifyContent="start"
>
{!agent.builtin_persona && (
<Button
prominence="tertiary"
icon={SvgEdit}
onClick={() => {
setPopoverOpen(false);
router.push(
`/app/agents/edit/${
agent.id
}?u=${Date.now()}&admin=true` as Route
);
}}
>
Edit Agent
</Button>
)}
<Button
prominence="tertiary"
icon={agent.is_visible ? SvgEyeClosed : SvgEye}
onClick={() => {
setPopoverOpen(false);
handleAction(
() => toggleAgentVisibility(agent.id, agent.is_visible),
agent.is_visible ? "Agent hidden" : "Agent visible"
);
}}
>
{agent.is_visible ? "Hide Agent" : "Show Agent"}
</Button>
<Button
prominence="tertiary"
icon={SvgStar}
onClick={() => openModal(Modal.TOGGLE_FEATURED)}
>
{agent.featured ? "Remove Featured" : "Set as Featured"}
</Button>
{!agent.builtin_persona && (
<>
<Separator paddingXRem={0.5} />
<Button
prominence="tertiary"
variant="danger"
icon={SvgTrash}
onClick={() => openModal(Modal.DELETE)}
>
Delete Agent
</Button>
</>
)}
</Section>
</Popover.Content>
</Popover>
{modal === Modal.DELETE && (
<ConfirmationModalLayout
icon={(props) => (
<SvgAlertCircle {...props} className="text-action-danger-05" />
)}
title="Delete Agent"
onClose={isSubmitting ? undefined : () => setModal(null)}
submit={
<Disabled disabled={isSubmitting}>
<Button
variant="danger"
onClick={() => {
handleAction(() => deleteAgent(agent.id), "Agent deleted");
}}
>
Delete
</Button>
</Disabled>
}
>
<Text as="p" text03>
Are you sure you want to delete{" "}
<Text as="span" text05>
{agent.name}
</Text>
? This action cannot be undone.
</Text>
</ConfirmationModalLayout>
)}
{modal === Modal.TOGGLE_FEATURED && (
<ConfirmationModalLayout
icon={SvgStar}
title={
agent.featured ? "Remove Featured Agent" : "Set Featured Agent"
}
onClose={isSubmitting ? undefined : () => setModal(null)}
submit={
<Disabled disabled={isSubmitting}>
<Button
onClick={() => {
handleAction(
() => toggleAgentFeatured(agent.id, agent.featured),
agent.featured
? "Agent removed from featured"
: "Agent set as featured"
);
}}
>
{agent.featured ? "Remove Featured" : "Set as Featured"}
</Button>
</Disabled>
}
>
<div className="flex flex-col gap-2">
<Text as="p" text03>
{agent.featured
? `Are you sure you want to remove the featured status of "${agent.name}"?`
: `Are you sure you want to set "${agent.name}" as a featured agent?`}
</Text>
<Text as="p" text03>
{agent.featured
? "Removing featured status will not affect its visibility or accessibility."
: "Setting as featured will make this agent public and visible to all users."}
</Text>
</div>
</ConfirmationModalLayout>
)}
</>
);
}

View File

@@ -1,209 +0,0 @@
"use client";
import { useMemo, useState } from "react";
import { Table, createTableColumns } from "@opal/components";
import { IllustrationContent } from "@opal/layouts";
import SvgNoResult from "@opal/illustrations/no-result";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import Text from "@/refresh-components/texts/Text";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import { Tag } from "@opal/components";
import type { MinimalUserSnapshot } from "@/lib/types";
import AgentAvatar from "@/refresh-components/avatars/AgentAvatar";
import type { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
import { useAdminPersonas } from "@/hooks/useAdminPersonas";
import { toast } from "@/hooks/useToast";
import AgentRowActions from "@/refresh-pages/admin/AgentsPage/AgentRowActions";
import { updateAgentDisplayPriorities } from "@/refresh-pages/admin/AgentsPage/svc";
import type { AgentRow } from "@/refresh-pages/admin/AgentsPage/interfaces";
import type { Persona } from "@/app/admin/agents/interfaces";
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
function toAgentRow(persona: Persona): AgentRow {
return {
id: persona.id,
name: persona.name,
description: persona.description,
is_public: persona.is_public,
is_visible: persona.is_visible,
featured: persona.featured,
builtin_persona: persona.builtin_persona,
display_priority: persona.display_priority,
owner: persona.owner,
groups: persona.groups,
users: persona.users,
uploaded_image_id: persona.uploaded_image_id,
icon_name: persona.icon_name,
};
}
// ---------------------------------------------------------------------------
// Column renderers
// ---------------------------------------------------------------------------
function renderCreatedByColumn(
_value: MinimalUserSnapshot | null,
row: AgentRow
) {
if (row.builtin_persona) {
return (
<Text as="span" mainUiBody text03>
System
</Text>
);
}
return (
<Text as="span" mainUiBody text03>
{row.owner?.email ?? "\u2014"}
</Text>
);
}
function renderAccessColumn(isPublic: boolean) {
return (
<Tag
color={isPublic ? "green" : "gray"}
size="sm"
title={isPublic ? "Public" : "Private"}
/>
);
}
// ---------------------------------------------------------------------------
// Columns
// ---------------------------------------------------------------------------
const tc = createTableColumns<AgentRow>();
function buildColumns(onMutate: () => void) {
return [
tc.displayColumn({
id: "avatar",
cell: (row) => (
<AgentAvatar
agent={row as unknown as MinimalPersonaSnapshot}
size={32}
/>
),
width: { fixed: 48 },
enableHiding: false,
}),
tc.column("name", {
header: "Name",
weight: 25,
minWidth: 150,
cell: (value) => (
<Text as="span" mainUiBody text05>
{value}
</Text>
),
}),
tc.column("description", {
header: "Description",
weight: 35,
minWidth: 200,
cell: (value) => (
<Text as="span" mainUiBody text03>
{value || "\u2014"}
</Text>
),
}),
tc.column("owner", {
header: "Created By",
weight: 20,
minWidth: 140,
cell: renderCreatedByColumn,
}),
tc.column("is_public", {
header: "Access",
weight: 12,
minWidth: 100,
cell: renderAccessColumn,
}),
tc.actions({
cell: (row) => <AgentRowActions agent={row} onMutate={onMutate} />,
}),
];
}
// ---------------------------------------------------------------------------
// Component
// ---------------------------------------------------------------------------
const PAGE_SIZE = 10;
export default function AgentsTable() {
const [searchTerm, setSearchTerm] = useState("");
const { personas, isLoading, error, refresh } = useAdminPersonas();
const columns = useMemo(() => buildColumns(refresh), [refresh]);
const agentRows: AgentRow[] = useMemo(
() => personas.filter((p) => !p.builtin_persona).map(toAgentRow),
[personas]
);
const handleReorder = async (
_orderedIds: string[],
changedOrders: Record<string, number>
) => {
try {
await updateAgentDisplayPriorities(changedOrders);
refresh();
} catch (err) {
toast.error(
err instanceof Error ? err.message : "Failed to update agent order"
);
refresh();
}
};
if (isLoading) {
return (
<div className="flex justify-center py-12">
<SimpleLoader className="h-6 w-6" />
</div>
);
}
if (error) {
return (
<Text as="p" secondaryBody text03>
Failed to load agents. Please try refreshing the page.
</Text>
);
}
return (
<div className="flex flex-col gap-3">
<InputTypeIn
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
placeholder="Search agents..."
leftSearchIcon
/>
<Table
data={agentRows}
columns={columns}
getRowId={(row) => String(row.id)}
pageSize={PAGE_SIZE}
searchTerm={searchTerm}
draggable={{
onReorder: handleReorder,
}}
emptyState={
<IllustrationContent
illustration={SvgNoResult}
title="No agents found"
description="No agents match the current search."
/>
}
footer={{}}
/>
</div>
);
}

View File

@@ -1,17 +0,0 @@
import type { MinimalUserSnapshot } from "@/lib/types";
export interface AgentRow {
id: number;
name: string;
description: string;
is_public: boolean;
is_visible: boolean;
featured: boolean;
builtin_persona: boolean;
display_priority: number | null;
owner: MinimalUserSnapshot | null;
groups: number[];
users: MinimalUserSnapshot[];
uploaded_image_id?: string;
icon_name?: string;
}

View File

@@ -1,68 +0,0 @@
async function parseErrorDetail(
res: Response,
fallback: string
): Promise<string> {
try {
const body = await res.json();
return body?.detail ?? fallback;
} catch {
return fallback;
}
}
export async function deleteAgent(agentId: number): Promise<void> {
const res = await fetch(`/api/persona/${agentId}`, {
method: "DELETE",
credentials: "include",
});
if (!res.ok) {
throw new Error(await parseErrorDetail(res, "Failed to delete agent"));
}
}
export async function toggleAgentFeatured(
agentId: number,
currentlyFeatured: boolean
): Promise<void> {
const res = await fetch(`/api/admin/persona/${agentId}/featured`, {
method: "PATCH",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ featured: !currentlyFeatured }),
credentials: "include",
});
if (!res.ok) {
throw new Error(
await parseErrorDetail(res, "Failed to toggle featured status")
);
}
}
export async function toggleAgentVisibility(
agentId: number,
currentlyVisible: boolean
): Promise<void> {
const res = await fetch(`/api/admin/persona/${agentId}/visible`, {
method: "PATCH",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ is_visible: !currentlyVisible }),
credentials: "include",
});
if (!res.ok) {
throw new Error(await parseErrorDetail(res, "Failed to toggle visibility"));
}
}
export async function updateAgentDisplayPriorities(
displayPriorityMap: Record<string, number>
): Promise<void> {
const res = await fetch("/api/admin/agents/display-priorities", {
method: "PATCH",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ display_priority_map: displayPriorityMap }),
});
if (!res.ok) {
throw new Error(
await parseErrorDetail(res, "Failed to update agent order")
);
}
}

View File

@@ -26,8 +26,7 @@ import type {
StatusFilter,
StatusCountMap,
} from "./interfaces";
import UserAvatar from "@/refresh-components/avatars/UserAvatar";
import type { User } from "@/lib/types";
import { getInitials } from "./utils";
// ---------------------------------------------------------------------------
// Column renderers
@@ -76,25 +75,20 @@ const tc = createTableColumns<UserRow>();
function buildColumns(onMutate: () => void) {
return [
tc.qualifier({
content: "icon",
getContent: (row) => {
const user = {
email: row.email,
personalization: row.personal_name
? { name: row.personal_name }
: undefined,
} as User;
return (props) => <UserAvatar user={user} size={props.size} />;
},
content: "avatar-user",
getInitials: (row) => getInitials(row.personal_name, row.email),
selectable: false,
}),
tc.column("email", {
header: "Name",
weight: 22,
minWidth: 140,
cell: renderNameColumn,
}),
tc.column("groups", {
header: "Groups",
weight: 24,
minWidth: 200,
enableSorting: false,
cell: (value, row) => (
<GroupsCell groups={value} user={row} onMutate={onMutate} />
@@ -103,16 +97,19 @@ function buildColumns(onMutate: () => void) {
tc.column("role", {
header: "Account Type",
weight: 16,
minWidth: 180,
cell: (_value, row) => <UserRoleCell user={row} onMutate={onMutate} />,
}),
tc.column("status", {
header: "Status",
weight: 14,
minWidth: 100,
cell: renderStatusColumn,
}),
tc.column("updated_at", {
header: "Last Updated",
weight: 14,
minWidth: 100,
cell: renderLastUpdatedColumn,
}),
tc.actions({
@@ -222,6 +219,7 @@ export default function UsersTable({
data={filteredUsers}
columns={columns}
getRowId={(row) => row.id ?? row.email}
qualifier="avatar"
pageSize={PAGE_SIZE}
searchTerm={searchTerm}
emptyState={

View File

@@ -0,0 +1,43 @@
import { getInitials } from "./utils";
describe("getInitials", () => {
it("returns first letters of first two name parts", () => {
expect(getInitials("Alice Smith", "alice@example.com")).toBe("AS");
});
it("returns first two chars of a single-word name", () => {
expect(getInitials("Alice", "alice@example.com")).toBe("AL");
});
it("handles three-word names (uses first two)", () => {
expect(getInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
});
it("falls back to email local part with dot separator", () => {
expect(getInitials(null, "alice.smith@example.com")).toBe("AS");
});
it("falls back to email local part with underscore separator", () => {
expect(getInitials(null, "alice_smith@example.com")).toBe("AS");
});
it("falls back to email local part with hyphen separator", () => {
expect(getInitials(null, "alice-smith@example.com")).toBe("AS");
});
it("uses first two chars of email local if no separator", () => {
expect(getInitials(null, "alice@example.com")).toBe("AL");
});
it("returns ? for empty email local part", () => {
expect(getInitials(null, "@example.com")).toBe("?");
});
it("uppercases the result", () => {
expect(getInitials("john doe", "jd@test.com")).toBe("JD");
});
it("trims whitespace from name", () => {
expect(getInitials(" Alice Smith ", "a@test.com")).toBe("AS");
});
});

View File

@@ -0,0 +1,23 @@
/**
* Derive display initials from a user's name or email.
*
* - If a name is provided, uses the first letter of the first two words.
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
* - Returns at most 2 uppercase characters.
*/
export function getInitials(name: string | null, email: string): string {
if (name) {
const parts = name.trim().split(/\s+/);
if (parts.length >= 2) {
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
}
return name.slice(0, 2).toUpperCase();
}
const local = email.split("@")[0];
if (!local) return "?";
const parts = local.split(/[._-]/);
if (parts.length >= 2) {
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
}
return local.slice(0, 2).toUpperCase();
}