mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-17 21:52:45 +00:00
Compare commits
1 Commits
bo/hook-cr
...
bo/hook_ap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
442838d80e |
@@ -1,233 +0,0 @@
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.models import Hook
|
||||
from onyx.db.models import HookExecutionLog
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ── Hook CRUD ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_hook_by_id(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
include_deleted: bool = False,
|
||||
include_creator: bool = False,
|
||||
) -> Hook | None:
|
||||
stmt = select(Hook).where(Hook.id == hook_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Hook.deleted.is_(False))
|
||||
if include_creator:
|
||||
stmt = stmt.options(selectinload(Hook.creator))
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def get_non_deleted_hook_by_hook_point(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
include_creator: bool = False,
|
||||
) -> Hook | None:
|
||||
stmt = (
|
||||
select(Hook).where(Hook.hook_point == hook_point).where(Hook.deleted.is_(False))
|
||||
)
|
||||
if include_creator:
|
||||
stmt = stmt.options(selectinload(Hook.creator))
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def get_hooks(
|
||||
*,
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
include_creator: bool = False,
|
||||
) -> list[Hook]:
|
||||
stmt = select(Hook)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Hook.deleted.is_(False))
|
||||
if include_creator:
|
||||
stmt = stmt.options(selectinload(Hook.creator))
|
||||
stmt = stmt.order_by(Hook.hook_point, Hook.created_at.desc())
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def create_hook__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
name: str,
|
||||
hook_point: HookPoint,
|
||||
endpoint_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
fail_strategy: HookFailStrategy,
|
||||
timeout_seconds: float,
|
||||
is_active: bool = False,
|
||||
creator_id: UUID | None = None,
|
||||
) -> Hook:
|
||||
"""Create a new hook for the given hook point.
|
||||
|
||||
At most one non-deleted hook per hook point is allowed. Raises
|
||||
OnyxError(CONFLICT) if a hook already exists, including under concurrent
|
||||
duplicate creates where the partial unique index fires an IntegrityError.
|
||||
"""
|
||||
existing = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if existing:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONFLICT,
|
||||
f"A hook for '{hook_point.value}' already exists (id={existing.id}).",
|
||||
)
|
||||
|
||||
hook = Hook(
|
||||
name=name,
|
||||
hook_point=hook_point,
|
||||
endpoint_url=endpoint_url,
|
||||
api_key=api_key,
|
||||
fail_strategy=fail_strategy,
|
||||
timeout_seconds=timeout_seconds,
|
||||
is_active=is_active,
|
||||
creator_id=creator_id,
|
||||
)
|
||||
# Use a savepoint so that a failed insert only rolls back this operation,
|
||||
# not the entire outer transaction.
|
||||
savepoint = db_session.begin_nested()
|
||||
try:
|
||||
db_session.add(hook)
|
||||
savepoint.commit()
|
||||
except IntegrityError as exc:
|
||||
savepoint.rollback()
|
||||
if "ix_hook_one_non_deleted_per_point" in str(exc.orig):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONFLICT,
|
||||
f"A hook for '{hook_point.value}' already exists.",
|
||||
)
|
||||
raise # re-raise unrelated integrity errors (FK violations, etc.)
|
||||
return hook
|
||||
|
||||
|
||||
def update_hook__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
name: str | None = None,
|
||||
endpoint_url: str | None | UnsetType = UNSET,
|
||||
api_key: str | None | UnsetType = UNSET,
|
||||
fail_strategy: HookFailStrategy | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
is_active: bool | None = None,
|
||||
is_reachable: bool | None = None,
|
||||
include_creator: bool = False,
|
||||
) -> Hook:
|
||||
"""Update hook fields.
|
||||
|
||||
Sentinel conventions:
|
||||
- endpoint_url, api_key: pass UNSET to leave unchanged; pass None to clear.
|
||||
- name, fail_strategy, timeout_seconds, is_active, is_reachable: pass None to leave unchanged.
|
||||
"""
|
||||
hook = get_hook_by_id(
|
||||
db_session=db_session, hook_id=hook_id, include_creator=include_creator
|
||||
)
|
||||
if hook is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
|
||||
|
||||
if name is not None:
|
||||
hook.name = name
|
||||
if not isinstance(endpoint_url, UnsetType):
|
||||
hook.endpoint_url = endpoint_url
|
||||
if not isinstance(api_key, UnsetType):
|
||||
hook.api_key = api_key # type: ignore[assignment] # EncryptedString coerces str → SensitiveValue at the ORM level
|
||||
if fail_strategy is not None:
|
||||
hook.fail_strategy = fail_strategy
|
||||
if timeout_seconds is not None:
|
||||
hook.timeout_seconds = timeout_seconds
|
||||
if is_active is not None:
|
||||
hook.is_active = is_active
|
||||
if is_reachable is not None:
|
||||
hook.is_reachable = is_reachable
|
||||
|
||||
db_session.flush()
|
||||
return hook
|
||||
|
||||
|
||||
def delete_hook__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
) -> None:
|
||||
hook = get_hook_by_id(db_session=db_session, hook_id=hook_id)
|
||||
if hook is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
|
||||
|
||||
hook.deleted = True
|
||||
hook.is_active = False
|
||||
db_session.flush()
|
||||
|
||||
|
||||
# ── HookExecutionLog CRUD ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def create_hook_execution_log__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
is_success: bool,
|
||||
error_message: str | None = None,
|
||||
status_code: int | None = None,
|
||||
duration_ms: int | None = None,
|
||||
) -> HookExecutionLog:
|
||||
log = HookExecutionLog(
|
||||
hook_id=hook_id,
|
||||
is_success=is_success,
|
||||
error_message=error_message,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
db_session.add(log)
|
||||
db_session.flush()
|
||||
return log
|
||||
|
||||
|
||||
def get_hook_execution_logs(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
limit: int,
|
||||
) -> list[HookExecutionLog]:
|
||||
stmt = (
|
||||
select(HookExecutionLog)
|
||||
.where(HookExecutionLog.hook_id == hook_id)
|
||||
.order_by(HookExecutionLog.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def cleanup_old_execution_logs__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
max_age_days: int,
|
||||
) -> int:
|
||||
"""Delete execution logs older than max_age_days. Returns the number of rows deleted."""
|
||||
cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
days=max_age_days
|
||||
)
|
||||
result: CursorResult = db_session.execute( # type: ignore[assignment]
|
||||
delete(HookExecutionLog)
|
||||
.where(HookExecutionLog.created_at < cutoff)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
return result.rowcount
|
||||
@@ -501,31 +501,20 @@ def query_vespa(
|
||||
response = http_client.post(SEARCH_ENDPOINT, json=params)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
response_text = (
|
||||
e.response.text if isinstance(e, httpx.HTTPStatusError) else None
|
||||
)
|
||||
status_code = (
|
||||
e.response.status_code if isinstance(e, httpx.HTTPStatusError) else None
|
||||
)
|
||||
yql_value = params.get("yql", "")
|
||||
yql_length = len(str(yql_value))
|
||||
|
||||
# Log each detail on its own line so log collectors capture them
|
||||
# as separate entries rather than truncating a single multiline msg
|
||||
error_base = "Failed to query Vespa"
|
||||
logger.error(
|
||||
f"Failed to query Vespa | "
|
||||
f"status={status_code} | "
|
||||
f"yql_length={yql_length} | "
|
||||
f"exception={str(e)}"
|
||||
f"{error_base}:\n"
|
||||
f"Request URL: {e.request.url}\n"
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
f"Request Payload: {params}\n"
|
||||
f"Exception: {str(e)}"
|
||||
+ (
|
||||
f"\nResponse: {e.response.text}"
|
||||
if isinstance(e, httpx.HTTPStatusError)
|
||||
else ""
|
||||
)
|
||||
)
|
||||
if response_text:
|
||||
logger.error(f"Vespa error response: {response_text[:1000]}")
|
||||
logger.error(f"Vespa request URL: {e.request.url}")
|
||||
|
||||
# Re-raise with diagnostics so callers see what actually went wrong
|
||||
raise httpx.HTTPError(
|
||||
f"Failed to query Vespa (status={status_code}, " f"yql_length={yql_length})"
|
||||
) from e
|
||||
raise httpx.HTTPError(error_base) from e
|
||||
|
||||
response_json: dict[str, Any] = response.json()
|
||||
|
||||
|
||||
@@ -43,22 +43,6 @@ def build_vespa_filters(
|
||||
return ""
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_weighted_set_filter(key: str, vals: list[str] | None) -> str:
|
||||
"""Build a Vespa weightedSet filter for large value lists.
|
||||
|
||||
Uses Vespa's native weightedSet() operator instead of OR-chained
|
||||
'contains' clauses. This is critical for fields like
|
||||
access_control_list where a single user may have tens of thousands
|
||||
of ACL entries — OR clauses at that scale cause Vespa to reject
|
||||
the query with HTTP 400."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
filtered = [val for val in vals if val]
|
||||
if not filtered:
|
||||
return ""
|
||||
items = ", ".join(f'"{val}":1' for val in filtered)
|
||||
return f"weightedSet({key}, {{{items}}})"
|
||||
|
||||
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
|
||||
"""For an integer field filter.
|
||||
Returns a bare clause or ""."""
|
||||
@@ -173,16 +157,11 @@ def build_vespa_filters(
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
|
||||
|
||||
# ACL filters — use weightedSet for efficient matching against the
|
||||
# access_control_list weightedset<string> field. OR-chaining thousands
|
||||
# of 'contains' clauses causes Vespa to reject the query (HTTP 400)
|
||||
# for users with large numbers of external permission groups.
|
||||
# ACL filters
|
||||
if filters.access_control_list is not None:
|
||||
_append(
|
||||
filter_parts,
|
||||
_build_weighted_set_filter(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
),
|
||||
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
|
||||
)
|
||||
|
||||
# Source type filters
|
||||
|
||||
127
backend/onyx/hooks/models.py
Normal file
127
backend/onyx/hooks/models.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
from pydantic import SecretStr
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
|
||||
NonEmptySecretStr = Annotated[SecretStr, Field(min_length=1)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookCreateRequest(BaseModel):
|
||||
name: str = Field(min_length=1)
|
||||
hook_point: HookPoint
|
||||
endpoint_url: str = Field(min_length=1)
|
||||
api_key: NonEmptySecretStr | None = None
|
||||
fail_strategy: HookFailStrategy | None = None # if None, uses HookPointSpec default
|
||||
timeout_seconds: float | None = Field(
|
||||
default=None, gt=0
|
||||
) # if None, uses HookPointSpec default
|
||||
|
||||
@field_validator("name", "endpoint_url")
|
||||
@classmethod
|
||||
def no_whitespace_only(cls, v: str) -> str:
|
||||
if not v.strip():
|
||||
raise ValueError("cannot be whitespace-only.")
|
||||
return v
|
||||
|
||||
|
||||
class HookUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
endpoint_url: str | None = None
|
||||
api_key: NonEmptySecretStr | None = None
|
||||
fail_strategy: HookFailStrategy | None = (
|
||||
None # if None in model_fields_set, reset to spec default
|
||||
)
|
||||
timeout_seconds: float | None = Field(
|
||||
default=None, gt=0
|
||||
) # if None in model_fields_set, reset to spec default
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_at_least_one_field(self) -> "HookUpdateRequest":
|
||||
if not self.model_fields_set:
|
||||
raise ValueError("At least one field must be provided for an update.")
|
||||
if "name" in self.model_fields_set and not (self.name or "").strip():
|
||||
raise ValueError("name cannot be cleared.")
|
||||
if (
|
||||
"endpoint_url" in self.model_fields_set
|
||||
and not (self.endpoint_url or "").strip()
|
||||
):
|
||||
raise ValueError("endpoint_url cannot be cleared.")
|
||||
return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookPointMetaResponse(BaseModel):
|
||||
hook_point: HookPoint
|
||||
display_name: str
|
||||
description: str
|
||||
docs_url: str | None
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
default_timeout_seconds: float
|
||||
default_fail_strategy: HookFailStrategy
|
||||
fail_hard_description: str
|
||||
|
||||
|
||||
class HookResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
hook_point: HookPoint
|
||||
# Nullable to match the DB column — endpoint_url is required on creation but
|
||||
# future hook point types may not use an external endpoint (e.g. built-in handlers).
|
||||
endpoint_url: str | None
|
||||
fail_strategy: HookFailStrategy
|
||||
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
|
||||
is_active: bool
|
||||
creator_email: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class HookValidateResponse(BaseModel):
|
||||
success: bool
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookHealthStatus(str, Enum):
|
||||
healthy = "healthy" # green — reachable, no failures in last 1h
|
||||
degraded = "degraded" # yellow — reachable, failures in last 1h
|
||||
unreachable = "unreachable" # red — is_reachable=false or null
|
||||
|
||||
|
||||
class HookFailureRecord(BaseModel):
|
||||
error_message: str | None = None
|
||||
status_code: int | None = None
|
||||
duration_ms: int | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class HookHealthResponse(BaseModel):
|
||||
status: HookHealthStatus
|
||||
recent_failures: list[HookFailureRecord] = Field(
|
||||
default_factory=list,
|
||||
description="Last 10 failures, newest first",
|
||||
max_length=10,
|
||||
)
|
||||
0
backend/onyx/hooks/points/__init__.py
Normal file
0
backend/onyx/hooks/points/__init__.py
Normal file
56
backend/onyx/hooks/points/base.py
Normal file
56
backend/onyx/hooks/points/base.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
|
||||
|
||||
_REQUIRED_ATTRS = (
|
||||
"hook_point",
|
||||
"display_name",
|
||||
"description",
|
||||
"default_timeout_seconds",
|
||||
"fail_hard_description",
|
||||
"default_fail_strategy",
|
||||
)
|
||||
|
||||
|
||||
class HookPointSpec(ABC):
|
||||
"""Static metadata and contract for a pipeline hook point.
|
||||
|
||||
Each hook point is a concrete subclass of this class. Onyx engineers
|
||||
own these definitions — customers never touch this code.
|
||||
|
||||
Instances are registered in onyx.hooks.registry._REGISTRY and served
|
||||
via GET /api/admin/hook-points.
|
||||
|
||||
Subclasses must define all attributes as class-level constants.
|
||||
"""
|
||||
|
||||
hook_point: HookPoint
|
||||
display_name: str
|
||||
description: str
|
||||
default_timeout_seconds: float
|
||||
fail_hard_description: str
|
||||
default_fail_strategy: HookFailStrategy
|
||||
docs_url: str | None = None
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Skip intermediate abstract subclasses — they may still be partially defined.
|
||||
if getattr(cls, "__abstractmethods__", None):
|
||||
return
|
||||
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
|
||||
if missing:
|
||||
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the request payload sent to the customer's endpoint."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
"""JSON schema describing the expected response from the customer's endpoint."""
|
||||
29
backend/onyx/hooks/points/document_ingestion.py
Normal file
29
backend/onyx/hooks/points/document_ingestion.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
class DocumentIngestionSpec(HookPointSpec):
|
||||
"""Hook point that runs during document ingestion.
|
||||
|
||||
# TODO(@Bo-Onyx): define call site, input/output schema, and timeout budget.
|
||||
"""
|
||||
|
||||
hook_point = HookPoint.DOCUMENT_INGESTION
|
||||
display_name = "Document Ingestion"
|
||||
description = "Runs during document ingestion. Allows filtering or transforming documents before indexing."
|
||||
default_timeout_seconds = 30.0
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define input schema
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
# TODO(@Bo-Onyx): define output schema
|
||||
return {"type": "object", "properties": {}}
|
||||
79
backend/onyx/hooks/points/query_processing.py
Normal file
79
backend/onyx/hooks/points/query_processing.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
class QueryProcessingSpec(HookPointSpec):
|
||||
"""Hook point that runs on every user query before it enters the pipeline.
|
||||
|
||||
Call site: inside handle_stream_message_objects() in
|
||||
backend/onyx/chat/process_message.py, immediately after message_text is
|
||||
assigned from the request and before create_new_chat_message() saves it.
|
||||
|
||||
This is the earliest possible point in the query pipeline:
|
||||
- Raw query — unmodified, exactly as the user typed it
|
||||
- No side effects yet — message has not been saved to DB
|
||||
- User identity is available for user-specific logic
|
||||
|
||||
Supported use cases:
|
||||
- Query rejection: block queries based on content or user context
|
||||
- Query rewriting: normalize, expand, or modify the query
|
||||
- PII removal: scrub sensitive data before the LLM sees it
|
||||
- Access control: reject queries from certain users or groups
|
||||
- Query auditing: log or track queries based on business rules
|
||||
"""
|
||||
|
||||
hook_point = HookPoint.QUERY_PROCESSING
|
||||
display_name = "Query Processing"
|
||||
description = (
|
||||
"Runs on every user query before it enters the pipeline. "
|
||||
"Allows rewriting, filtering, or rejecting queries."
|
||||
)
|
||||
default_timeout_seconds = 5.0 # user is actively waiting — keep tight
|
||||
fail_hard_description = (
|
||||
"The query will be blocked and the user will see an error message."
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The raw query string exactly as the user typed it.",
|
||||
},
|
||||
"user_email": {
|
||||
"type": ["string", "null"],
|
||||
"description": "Email of the user submitting the query, or null if unauthenticated.",
|
||||
},
|
||||
},
|
||||
"required": ["query", "user_email"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"The (optionally modified) query to use. "
|
||||
"Set to null to reject the query."
|
||||
),
|
||||
},
|
||||
"rejection_message": {
|
||||
"type": ["string", "null"],
|
||||
"description": (
|
||||
"Message shown to the user when query is null. "
|
||||
"Falls back to a generic message if not provided."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
45
backend/onyx/hooks/registry.py
Normal file
45
backend/onyx/hooks/registry.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionSpec
|
||||
from onyx.hooks.points.query_processing import QueryProcessingSpec
|
||||
|
||||
# Internal: use `monkeypatch.setattr(registry_module, "_REGISTRY", {...})` to override in tests.
|
||||
_REGISTRY: dict[HookPoint, HookPointSpec] = {
|
||||
HookPoint.DOCUMENT_INGESTION: DocumentIngestionSpec(),
|
||||
HookPoint.QUERY_PROCESSING: QueryProcessingSpec(),
|
||||
}
|
||||
|
||||
|
||||
def validate_registry() -> None:
|
||||
"""Assert that every HookPoint enum value has a registered spec.
|
||||
|
||||
Call once at application startup (e.g. from the FastAPI lifespan hook).
|
||||
Raises RuntimeError if any hook point is missing a spec.
|
||||
"""
|
||||
missing = set(HookPoint) - set(_REGISTRY)
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"Hook point(s) have no registered spec: {missing}. "
|
||||
"Add an entry to onyx.hooks.registry._REGISTRY."
|
||||
)
|
||||
|
||||
|
||||
def get_hook_point_spec(hook_point: HookPoint) -> HookPointSpec:
|
||||
"""Returns the spec for a given hook point.
|
||||
|
||||
Raises ValueError if the hook point has no registered spec — this is a
|
||||
programmer error; every HookPoint enum value must have a corresponding spec
|
||||
in _REGISTRY.
|
||||
"""
|
||||
try:
|
||||
return _REGISTRY[hook_point]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"No spec registered for hook point {hook_point!r}. "
|
||||
"Add an entry to onyx.hooks.registry._REGISTRY."
|
||||
)
|
||||
|
||||
|
||||
def get_all_specs() -> list[HookPointSpec]:
|
||||
"""Returns the specs for all registered hook points."""
|
||||
return list(_REGISTRY.values())
|
||||
@@ -62,6 +62,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.error_handling.exceptions import register_onyx_exception_handlers
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.hooks.registry import validate_registry
|
||||
from onyx.server.api_key.api import router as api_key_router
|
||||
from onyx.server.auth_check import check_router_auth
|
||||
from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
@@ -308,6 +309,7 @@ def validate_no_vector_db_settings() -> None:
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
validate_no_vector_db_settings()
|
||||
validate_cache_backend_settings()
|
||||
validate_registry()
|
||||
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
|
||||
22
backend/tests/unit/onyx/hooks/test_base_spec.py
Normal file
22
backend/tests/unit/onyx/hooks/test_base_spec.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
def test_init_subclass_raises_for_missing_attrs() -> None:
|
||||
with pytest.raises(TypeError, match="must define class attributes"):
|
||||
|
||||
class IncompleteSpec(HookPointSpec):
|
||||
hook_point = HookPoint.QUERY_PROCESSING
|
||||
# missing display_name, description, etc.
|
||||
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return {}
|
||||
86
backend/tests/unit/onyx/hooks/test_models.py
Normal file
86
backend/tests/unit/onyx/hooks/test_models.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.models import HookCreateRequest
|
||||
from onyx.hooks.models import HookUpdateRequest
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_empty() -> None:
|
||||
# No fields supplied at all
|
||||
with pytest.raises(ValidationError, match="At least one field must be provided"):
|
||||
HookUpdateRequest()
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_null_name_when_only_field() -> None:
|
||||
# Explicitly setting name=None is rejected as name cannot be cleared
|
||||
with pytest.raises(ValidationError, match="name cannot be cleared"):
|
||||
HookUpdateRequest(name=None)
|
||||
|
||||
|
||||
def test_hook_update_request_accepts_single_field() -> None:
|
||||
req = HookUpdateRequest(name="new name")
|
||||
assert req.name == "new name"
|
||||
|
||||
|
||||
def test_hook_update_request_accepts_partial_fields() -> None:
|
||||
req = HookUpdateRequest(fail_strategy=HookFailStrategy.SOFT, timeout_seconds=10.0)
|
||||
assert req.fail_strategy == HookFailStrategy.SOFT
|
||||
assert req.timeout_seconds == 10.0
|
||||
assert req.name is None
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_null_name() -> None:
|
||||
with pytest.raises(ValidationError, match="name cannot be cleared"):
|
||||
HookUpdateRequest(name=None, fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_empty_name() -> None:
|
||||
with pytest.raises(ValidationError, match="name cannot be cleared"):
|
||||
HookUpdateRequest(name="", fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_null_endpoint_url() -> None:
|
||||
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
|
||||
HookUpdateRequest(endpoint_url=None, fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_empty_endpoint_url() -> None:
|
||||
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
|
||||
HookUpdateRequest(endpoint_url="", fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_update_request_allows_null_api_key() -> None:
|
||||
# api_key=null is valid — means "clear the api key"
|
||||
req = HookUpdateRequest(api_key=None)
|
||||
assert req.api_key is None
|
||||
assert "api_key" in req.model_fields_set
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_whitespace_name() -> None:
|
||||
with pytest.raises(ValidationError, match="name cannot be cleared"):
|
||||
HookUpdateRequest(name=" ", fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_update_request_rejects_whitespace_endpoint_url() -> None:
|
||||
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
|
||||
HookUpdateRequest(endpoint_url=" ", fail_strategy=HookFailStrategy.SOFT)
|
||||
|
||||
|
||||
def test_hook_create_request_rejects_whitespace_name() -> None:
|
||||
with pytest.raises(ValidationError, match="whitespace-only"):
|
||||
HookCreateRequest(
|
||||
name=" ",
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
endpoint_url="https://example.com/hook",
|
||||
)
|
||||
|
||||
|
||||
def test_hook_create_request_rejects_whitespace_endpoint_url() -> None:
|
||||
with pytest.raises(ValidationError, match="whitespace-only"):
|
||||
HookCreateRequest(
|
||||
name="my hook",
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
endpoint_url=" ",
|
||||
)
|
||||
54
backend/tests/unit/onyx/hooks/test_query_processing_spec.py
Normal file
54
backend/tests/unit/onyx/hooks/test_query_processing_spec.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.query_processing import QueryProcessingSpec
|
||||
|
||||
|
||||
def test_hook_point_is_query_processing() -> None:
|
||||
assert QueryProcessingSpec().hook_point == HookPoint.QUERY_PROCESSING
|
||||
|
||||
|
||||
def test_default_fail_strategy_is_hard() -> None:
|
||||
assert QueryProcessingSpec().default_fail_strategy == HookFailStrategy.HARD
|
||||
|
||||
|
||||
def test_default_timeout_seconds() -> None:
|
||||
# User is actively waiting — 5s is the documented contract for this hook point
|
||||
assert QueryProcessingSpec().default_timeout_seconds == 5.0
|
||||
|
||||
|
||||
def test_input_schema_required_fields() -> None:
|
||||
schema = QueryProcessingSpec().input_schema
|
||||
assert schema["type"] == "object"
|
||||
required = schema["required"]
|
||||
assert "query" in required
|
||||
assert "user_email" in required
|
||||
|
||||
|
||||
def test_input_schema_query_is_string() -> None:
|
||||
props = QueryProcessingSpec().input_schema["properties"]
|
||||
assert props["query"]["type"] == "string"
|
||||
|
||||
|
||||
def test_input_schema_user_email_is_nullable() -> None:
|
||||
props = QueryProcessingSpec().input_schema["properties"]
|
||||
assert "null" in props["user_email"]["type"]
|
||||
|
||||
|
||||
def test_output_schema_query_is_required() -> None:
|
||||
schema = QueryProcessingSpec().output_schema
|
||||
assert "query" in schema["required"]
|
||||
|
||||
|
||||
def test_output_schema_query_is_nullable() -> None:
|
||||
# null means "reject the query"
|
||||
props = QueryProcessingSpec().output_schema["properties"]
|
||||
assert "null" in props["query"]["type"]
|
||||
|
||||
|
||||
def test_output_schema_rejection_message_is_optional() -> None:
|
||||
schema = QueryProcessingSpec().output_schema
|
||||
assert "rejection_message" not in schema.get("required", [])
|
||||
|
||||
|
||||
def test_input_schema_no_additional_properties() -> None:
|
||||
assert QueryProcessingSpec().input_schema.get("additionalProperties") is False
|
||||
47
backend/tests/unit/onyx/hooks/test_registry.py
Normal file
47
backend/tests/unit/onyx/hooks/test_registry.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks import registry as registry_module
|
||||
from onyx.hooks.registry import get_all_specs
|
||||
from onyx.hooks.registry import get_hook_point_spec
|
||||
from onyx.hooks.registry import validate_registry
|
||||
|
||||
|
||||
def test_registry_covers_all_hook_points() -> None:
|
||||
"""Every HookPoint enum member must have a registered spec."""
|
||||
assert {s.hook_point for s in get_all_specs()} == set(
|
||||
HookPoint
|
||||
), f"Missing specs for: {set(HookPoint) - {s.hook_point for s in get_all_specs()}}"
|
||||
|
||||
|
||||
def test_get_hook_point_spec_returns_correct_spec() -> None:
|
||||
for hook_point in HookPoint:
|
||||
spec = get_hook_point_spec(hook_point)
|
||||
assert spec.hook_point == hook_point
|
||||
|
||||
|
||||
def test_get_all_specs_returns_all() -> None:
|
||||
specs = get_all_specs()
|
||||
assert len(specs) == len(HookPoint)
|
||||
assert {s.hook_point for s in specs} == set(HookPoint)
|
||||
|
||||
|
||||
def test_get_hook_point_spec_raises_for_unregistered(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""get_hook_point_spec raises ValueError when a hook point has no spec."""
|
||||
monkeypatch.setattr(registry_module, "_REGISTRY", {})
|
||||
with pytest.raises(ValueError, match="No spec registered for hook point"):
|
||||
get_hook_point_spec(HookPoint.QUERY_PROCESSING)
|
||||
|
||||
|
||||
def test_validate_registry_passes() -> None:
|
||||
validate_registry() # should not raise with the real registry
|
||||
|
||||
|
||||
def test_validate_registry_raises_for_incomplete(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(registry_module, "_REGISTRY", {})
|
||||
with pytest.raises(RuntimeError, match="Hook point\\(s\\) have no registered spec"):
|
||||
validate_registry()
|
||||
@@ -44,13 +44,13 @@ class TestBuildVespaFilters:
|
||||
assert result == f'({SOURCE_TYPE} contains "web") and '
|
||||
|
||||
def test_acl(self) -> None:
|
||||
"""Test with acls — uses weightedSet operator for efficient matching."""
|
||||
"""Test with acls."""
|
||||
# Single ACL
|
||||
filters = IndexFilters(access_control_list=["user1"])
|
||||
result = build_vespa_filters(filters)
|
||||
assert (
|
||||
result
|
||||
== f'!({HIDDEN}=true) and weightedSet(access_control_list, {{"user1":1}}) and '
|
||||
== f'!({HIDDEN}=true) and (access_control_list contains "user1") and '
|
||||
)
|
||||
|
||||
# Multiple ACL's
|
||||
@@ -58,7 +58,7 @@ class TestBuildVespaFilters:
|
||||
result = build_vespa_filters(filters)
|
||||
assert (
|
||||
result
|
||||
== f'!({HIDDEN}=true) and weightedSet(access_control_list, {{"user2":1, "group2":1}}) and '
|
||||
== f'!({HIDDEN}=true) and (access_control_list contains "user2" or access_control_list contains "group2") and '
|
||||
)
|
||||
|
||||
def test_tenant_filter(self) -> None:
|
||||
@@ -250,7 +250,7 @@ class TestBuildVespaFilters:
|
||||
result = build_vespa_filters(filters)
|
||||
|
||||
expected = f"!({HIDDEN}=true) and "
|
||||
expected += 'weightedSet(access_control_list, {"user1":1, "group1":1}) and '
|
||||
expected += '(access_control_list contains "user1" or access_control_list contains "group1") and '
|
||||
expected += f'({SOURCE_TYPE} contains "web") and '
|
||||
expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
|
||||
# Knowledge scope filters are OR'd together
|
||||
@@ -290,38 +290,6 @@ class TestBuildVespaFilters:
|
||||
expected = f'!({HIDDEN}=true) and (({DOCUMENT_SETS} contains "engineering") or ({DOCUMENT_ID} contains "{str(id1)}")) and '
|
||||
assert expected == result
|
||||
|
||||
def test_acl_large_list_uses_weighted_set(self) -> None:
|
||||
"""Verify that large ACL lists produce a weightedSet clause
|
||||
instead of OR-chained contains — this is what prevents Vespa
|
||||
HTTP 400 errors for users with thousands of permission groups."""
|
||||
acl = [f"external_group:google_drive_{i}" for i in range(10_000)]
|
||||
acl += ["user_email:user@example.com", "__PUBLIC__"]
|
||||
filters = IndexFilters(access_control_list=acl)
|
||||
result = build_vespa_filters(filters)
|
||||
|
||||
assert "weightedSet(access_control_list, {" in result
|
||||
# Must NOT contain OR-chained contains clauses
|
||||
assert "access_control_list contains" not in result
|
||||
# All entries should be present
|
||||
assert '"external_group:google_drive_0":1' in result
|
||||
assert '"external_group:google_drive_9999":1' in result
|
||||
assert '"user_email:user@example.com":1' in result
|
||||
assert '"__PUBLIC__":1' in result
|
||||
|
||||
def test_acl_empty_strings_filtered(self) -> None:
|
||||
"""Empty strings in the ACL list should be filtered out."""
|
||||
filters = IndexFilters(access_control_list=["user1", "", "group1"])
|
||||
result = build_vespa_filters(filters)
|
||||
assert (
|
||||
result
|
||||
== f'!({HIDDEN}=true) and weightedSet(access_control_list, {{"user1":1, "group1":1}}) and '
|
||||
)
|
||||
|
||||
# All empty
|
||||
filters = IndexFilters(access_control_list=["", ""])
|
||||
result = build_vespa_filters(filters)
|
||||
assert result == f"!({HIDDEN}=true) and "
|
||||
|
||||
def test_empty_or_none_values(self) -> None:
|
||||
"""Test with empty or None values in filter lists."""
|
||||
# Empty strings in document set
|
||||
|
||||
Reference in New Issue
Block a user