Compare commits

..

1 Commits

Author SHA1 Message Date
Bo-Onyx
442838d80e chore(hooks): Define Hook Point 2026-03-17 13:58:20 -07:00
15 changed files with 565 additions and 315 deletions

View File

@@ -1,233 +0,0 @@
import datetime
from uuid import UUID
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.models import Hook
from onyx.db.models import HookExecutionLog
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
# ── Hook CRUD ────────────────────────────────────────────────────────────
def get_hook_by_id(
*,
db_session: Session,
hook_id: int,
include_deleted: bool = False,
include_creator: bool = False,
) -> Hook | None:
stmt = select(Hook).where(Hook.id == hook_id)
if not include_deleted:
stmt = stmt.where(Hook.deleted.is_(False))
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
return db_session.scalar(stmt)
def get_non_deleted_hook_by_hook_point(
*,
db_session: Session,
hook_point: HookPoint,
include_creator: bool = False,
) -> Hook | None:
stmt = (
select(Hook).where(Hook.hook_point == hook_point).where(Hook.deleted.is_(False))
)
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
return db_session.scalar(stmt)
def get_hooks(
*,
db_session: Session,
include_deleted: bool = False,
include_creator: bool = False,
) -> list[Hook]:
stmt = select(Hook)
if not include_deleted:
stmt = stmt.where(Hook.deleted.is_(False))
if include_creator:
stmt = stmt.options(selectinload(Hook.creator))
stmt = stmt.order_by(Hook.hook_point, Hook.created_at.desc())
return list(db_session.scalars(stmt).all())
def create_hook__no_commit(
*,
db_session: Session,
name: str,
hook_point: HookPoint,
endpoint_url: str | None = None,
api_key: str | None = None,
fail_strategy: HookFailStrategy,
timeout_seconds: float,
is_active: bool = False,
creator_id: UUID | None = None,
) -> Hook:
"""Create a new hook for the given hook point.
At most one non-deleted hook per hook point is allowed. Raises
OnyxError(CONFLICT) if a hook already exists, including under concurrent
duplicate creates where the partial unique index fires an IntegrityError.
"""
existing = get_non_deleted_hook_by_hook_point(
db_session=db_session, hook_point=hook_point
)
if existing:
raise OnyxError(
OnyxErrorCode.CONFLICT,
f"A hook for '{hook_point.value}' already exists (id={existing.id}).",
)
hook = Hook(
name=name,
hook_point=hook_point,
endpoint_url=endpoint_url,
api_key=api_key,
fail_strategy=fail_strategy,
timeout_seconds=timeout_seconds,
is_active=is_active,
creator_id=creator_id,
)
# Use a savepoint so that a failed insert only rolls back this operation,
# not the entire outer transaction.
savepoint = db_session.begin_nested()
try:
db_session.add(hook)
savepoint.commit()
except IntegrityError as exc:
savepoint.rollback()
if "ix_hook_one_non_deleted_per_point" in str(exc.orig):
raise OnyxError(
OnyxErrorCode.CONFLICT,
f"A hook for '{hook_point.value}' already exists.",
)
raise # re-raise unrelated integrity errors (FK violations, etc.)
return hook
def update_hook__no_commit(
*,
db_session: Session,
hook_id: int,
name: str | None = None,
endpoint_url: str | None | UnsetType = UNSET,
api_key: str | None | UnsetType = UNSET,
fail_strategy: HookFailStrategy | None = None,
timeout_seconds: float | None = None,
is_active: bool | None = None,
is_reachable: bool | None = None,
include_creator: bool = False,
) -> Hook:
"""Update hook fields.
Sentinel conventions:
- endpoint_url, api_key: pass UNSET to leave unchanged; pass None to clear.
- name, fail_strategy, timeout_seconds, is_active, is_reachable: pass None to leave unchanged.
"""
hook = get_hook_by_id(
db_session=db_session, hook_id=hook_id, include_creator=include_creator
)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
if name is not None:
hook.name = name
if not isinstance(endpoint_url, UnsetType):
hook.endpoint_url = endpoint_url
if not isinstance(api_key, UnsetType):
hook.api_key = api_key # type: ignore[assignment] # EncryptedString coerces str → SensitiveValue at the ORM level
if fail_strategy is not None:
hook.fail_strategy = fail_strategy
if timeout_seconds is not None:
hook.timeout_seconds = timeout_seconds
if is_active is not None:
hook.is_active = is_active
if is_reachable is not None:
hook.is_reachable = is_reachable
db_session.flush()
return hook
def delete_hook__no_commit(
*,
db_session: Session,
hook_id: int,
) -> None:
hook = get_hook_by_id(db_session=db_session, hook_id=hook_id)
if hook is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
hook.deleted = True
hook.is_active = False
db_session.flush()
# ── HookExecutionLog CRUD ────────────────────────────────────────────────
def create_hook_execution_log__no_commit(
*,
db_session: Session,
hook_id: int,
is_success: bool,
error_message: str | None = None,
status_code: int | None = None,
duration_ms: int | None = None,
) -> HookExecutionLog:
log = HookExecutionLog(
hook_id=hook_id,
is_success=is_success,
error_message=error_message,
status_code=status_code,
duration_ms=duration_ms,
)
db_session.add(log)
db_session.flush()
return log
def get_hook_execution_logs(
*,
db_session: Session,
hook_id: int,
limit: int,
) -> list[HookExecutionLog]:
stmt = (
select(HookExecutionLog)
.where(HookExecutionLog.hook_id == hook_id)
.order_by(HookExecutionLog.created_at.desc())
.limit(limit)
)
return list(db_session.scalars(stmt).all())
def cleanup_old_execution_logs__no_commit(
*,
db_session: Session,
max_age_days: int,
) -> int:
"""Delete execution logs older than max_age_days. Returns the number of rows deleted."""
cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=max_age_days
)
result: CursorResult = db_session.execute( # type: ignore[assignment]
delete(HookExecutionLog)
.where(HookExecutionLog.created_at < cutoff)
.execution_options(synchronize_session=False)
)
return result.rowcount

View File

@@ -501,31 +501,20 @@ def query_vespa(
response = http_client.post(SEARCH_ENDPOINT, json=params)
response.raise_for_status()
except httpx.HTTPError as e:
response_text = (
e.response.text if isinstance(e, httpx.HTTPStatusError) else None
)
status_code = (
e.response.status_code if isinstance(e, httpx.HTTPStatusError) else None
)
yql_value = params.get("yql", "")
yql_length = len(str(yql_value))
# Log each detail on its own line so log collectors capture them
# as separate entries rather than truncating a single multiline msg
error_base = "Failed to query Vespa"
logger.error(
f"Failed to query Vespa | "
f"status={status_code} | "
f"yql_length={yql_length} | "
f"exception={str(e)}"
f"{error_base}:\n"
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
f"Exception: {str(e)}"
+ (
f"\nResponse: {e.response.text}"
if isinstance(e, httpx.HTTPStatusError)
else ""
)
)
if response_text:
logger.error(f"Vespa error response: {response_text[:1000]}")
logger.error(f"Vespa request URL: {e.request.url}")
# Re-raise with diagnostics so callers see what actually went wrong
raise httpx.HTTPError(
f"Failed to query Vespa (status={status_code}, " f"yql_length={yql_length})"
) from e
raise httpx.HTTPError(error_base) from e
response_json: dict[str, Any] = response.json()

View File

@@ -43,22 +43,6 @@ def build_vespa_filters(
return ""
return f"({' or '.join(eq_elems)})"
def _build_weighted_set_filter(key: str, vals: list[str] | None) -> str:
"""Build a Vespa weightedSet filter for large value lists.
Uses Vespa's native weightedSet() operator instead of OR-chained
'contains' clauses. This is critical for fields like
access_control_list where a single user may have tens of thousands
of ACL entries — OR clauses at that scale cause Vespa to reject
the query with HTTP 400."""
if not key or not vals:
return ""
filtered = [val for val in vals if val]
if not filtered:
return ""
items = ", ".join(f'"{val}":1' for val in filtered)
return f"weightedSet({key}, {{{items}}})"
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
"""For an integer field filter.
Returns a bare clause or ""."""
@@ -173,16 +157,11 @@ def build_vespa_filters(
if filters.tenant_id and MULTI_TENANT:
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
# ACL filters — use weightedSet for efficient matching against the
# access_control_list weightedset<string> field. OR-chaining thousands
# of 'contains' clauses causes Vespa to reject the query (HTTP 400)
# for users with large numbers of external permission groups.
# ACL filters
if filters.access_control_list is not None:
_append(
filter_parts,
_build_weighted_set_filter(
ACCESS_CONTROL_LIST, filters.access_control_list
),
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
)
# Source type filters

View File

@@ -0,0 +1,127 @@
from datetime import datetime
from enum import Enum
from typing import Annotated
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from pydantic import SecretStr
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
NonEmptySecretStr = Annotated[SecretStr, Field(min_length=1)]
# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------
class HookCreateRequest(BaseModel):
name: str = Field(min_length=1)
hook_point: HookPoint
endpoint_url: str = Field(min_length=1)
api_key: NonEmptySecretStr | None = None
fail_strategy: HookFailStrategy | None = None # if None, uses HookPointSpec default
timeout_seconds: float | None = Field(
default=None, gt=0
) # if None, uses HookPointSpec default
@field_validator("name", "endpoint_url")
@classmethod
def no_whitespace_only(cls, v: str) -> str:
if not v.strip():
raise ValueError("cannot be whitespace-only.")
return v
class HookUpdateRequest(BaseModel):
name: str | None = None
endpoint_url: str | None = None
api_key: NonEmptySecretStr | None = None
fail_strategy: HookFailStrategy | None = (
None # if None in model_fields_set, reset to spec default
)
timeout_seconds: float | None = Field(
default=None, gt=0
) # if None in model_fields_set, reset to spec default
@model_validator(mode="after")
def require_at_least_one_field(self) -> "HookUpdateRequest":
if not self.model_fields_set:
raise ValueError("At least one field must be provided for an update.")
if "name" in self.model_fields_set and not (self.name or "").strip():
raise ValueError("name cannot be cleared.")
if (
"endpoint_url" in self.model_fields_set
and not (self.endpoint_url or "").strip()
):
raise ValueError("endpoint_url cannot be cleared.")
return self
# ---------------------------------------------------------------------------
# Response models
# ---------------------------------------------------------------------------
class HookPointMetaResponse(BaseModel):
hook_point: HookPoint
display_name: str
description: str
docs_url: str | None
input_schema: dict[str, Any]
output_schema: dict[str, Any]
default_timeout_seconds: float
default_fail_strategy: HookFailStrategy
fail_hard_description: str
class HookResponse(BaseModel):
id: int
name: str
hook_point: HookPoint
# Nullable to match the DB column — endpoint_url is required on creation but
# future hook point types may not use an external endpoint (e.g. built-in handlers).
endpoint_url: str | None
fail_strategy: HookFailStrategy
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
is_active: bool
creator_email: str | None
created_at: datetime
updated_at: datetime
class HookValidateResponse(BaseModel):
success: bool
error_message: str | None = None
# ---------------------------------------------------------------------------
# Health models
# ---------------------------------------------------------------------------
class HookHealthStatus(str, Enum):
healthy = "healthy" # green — reachable, no failures in last 1h
degraded = "degraded" # yellow — reachable, failures in last 1h
unreachable = "unreachable" # red — is_reachable=false or null
class HookFailureRecord(BaseModel):
error_message: str | None = None
status_code: int | None = None
duration_ms: int | None = None
created_at: datetime
class HookHealthResponse(BaseModel):
status: HookHealthStatus
recent_failures: list[HookFailureRecord] = Field(
default_factory=list,
description="Last 10 failures, newest first",
max_length=10,
)

View File

View File

@@ -0,0 +1,56 @@
from abc import ABC
from abc import abstractmethod
from typing import Any
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
_REQUIRED_ATTRS = (
"hook_point",
"display_name",
"description",
"default_timeout_seconds",
"fail_hard_description",
"default_fail_strategy",
)
class HookPointSpec(ABC):
"""Static metadata and contract for a pipeline hook point.
Each hook point is a concrete subclass of this class. Onyx engineers
own these definitions — customers never touch this code.
Instances are registered in onyx.hooks.registry._REGISTRY and served
via GET /api/admin/hook-points.
Subclasses must define all attributes as class-level constants.
"""
hook_point: HookPoint
display_name: str
description: str
default_timeout_seconds: float
fail_hard_description: str
default_fail_strategy: HookFailStrategy
docs_url: str | None = None
def __init_subclass__(cls, **kwargs: object) -> None:
super().__init_subclass__(**kwargs)
# Skip intermediate abstract subclasses — they may still be partially defined.
if getattr(cls, "__abstractmethods__", None):
return
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
if missing:
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
@property
@abstractmethod
def input_schema(self) -> dict[str, Any]:
"""JSON schema describing the request payload sent to the customer's endpoint."""
@property
@abstractmethod
def output_schema(self) -> dict[str, Any]:
"""JSON schema describing the expected response from the customer's endpoint."""

View File

@@ -0,0 +1,29 @@
from typing import Any
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
class DocumentIngestionSpec(HookPointSpec):
"""Hook point that runs during document ingestion.
# TODO(@Bo-Onyx): define call site, input/output schema, and timeout budget.
"""
hook_point = HookPoint.DOCUMENT_INGESTION
display_name = "Document Ingestion"
description = "Runs during document ingestion. Allows filtering or transforming documents before indexing."
default_timeout_seconds = 30.0
fail_hard_description = "The document will not be indexed."
default_fail_strategy = HookFailStrategy.HARD
@property
def input_schema(self) -> dict[str, Any]:
# TODO(@Bo-Onyx): define input schema
return {"type": "object", "properties": {}}
@property
def output_schema(self) -> dict[str, Any]:
# TODO(@Bo-Onyx): define output schema
return {"type": "object", "properties": {}}

View File

@@ -0,0 +1,79 @@
from typing import Any
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
class QueryProcessingSpec(HookPointSpec):
"""Hook point that runs on every user query before it enters the pipeline.
Call site: inside handle_stream_message_objects() in
backend/onyx/chat/process_message.py, immediately after message_text is
assigned from the request and before create_new_chat_message() saves it.
This is the earliest possible point in the query pipeline:
- Raw query — unmodified, exactly as the user typed it
- No side effects yet — message has not been saved to DB
- User identity is available for user-specific logic
Supported use cases:
- Query rejection: block queries based on content or user context
- Query rewriting: normalize, expand, or modify the query
- PII removal: scrub sensitive data before the LLM sees it
- Access control: reject queries from certain users or groups
- Query auditing: log or track queries based on business rules
"""
hook_point = HookPoint.QUERY_PROCESSING
display_name = "Query Processing"
description = (
"Runs on every user query before it enters the pipeline. "
"Allows rewriting, filtering, or rejecting queries."
)
default_timeout_seconds = 5.0 # user is actively waiting — keep tight
fail_hard_description = (
"The query will be blocked and the user will see an error message."
)
default_fail_strategy = HookFailStrategy.HARD
@property
def input_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The raw query string exactly as the user typed it.",
},
"user_email": {
"type": ["string", "null"],
"description": "Email of the user submitting the query, or null if unauthenticated.",
},
},
"required": ["query", "user_email"],
"additionalProperties": False,
}
@property
def output_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": ["string", "null"],
"description": (
"The (optionally modified) query to use. "
"Set to null to reject the query."
),
},
"rejection_message": {
"type": ["string", "null"],
"description": (
"Message shown to the user when query is null. "
"Falls back to a generic message if not provided."
),
},
},
"required": ["query"],
}

View File

@@ -0,0 +1,45 @@
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
from onyx.hooks.points.document_ingestion import DocumentIngestionSpec
from onyx.hooks.points.query_processing import QueryProcessingSpec
# Internal: use `monkeypatch.setattr(registry_module, "_REGISTRY", {...})` to override in tests.
_REGISTRY: dict[HookPoint, HookPointSpec] = {
HookPoint.DOCUMENT_INGESTION: DocumentIngestionSpec(),
HookPoint.QUERY_PROCESSING: QueryProcessingSpec(),
}
def validate_registry() -> None:
"""Assert that every HookPoint enum value has a registered spec.
Call once at application startup (e.g. from the FastAPI lifespan hook).
Raises RuntimeError if any hook point is missing a spec.
"""
missing = set(HookPoint) - set(_REGISTRY)
if missing:
raise RuntimeError(
f"Hook point(s) have no registered spec: {missing}. "
"Add an entry to onyx.hooks.registry._REGISTRY."
)
def get_hook_point_spec(hook_point: HookPoint) -> HookPointSpec:
"""Returns the spec for a given hook point.
Raises ValueError if the hook point has no registered spec — this is a
programmer error; every HookPoint enum value must have a corresponding spec
in _REGISTRY.
"""
try:
return _REGISTRY[hook_point]
except KeyError:
raise ValueError(
f"No spec registered for hook point {hook_point!r}. "
"Add an entry to onyx.hooks.registry._REGISTRY."
)
def get_all_specs() -> list[HookPointSpec]:
"""Returns the specs for all registered hook points."""
return list(_REGISTRY.values())

View File

@@ -62,6 +62,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.error_handling.exceptions import register_onyx_exception_handlers
from onyx.file_store.file_store import get_default_file_store
from onyx.hooks.registry import validate_registry
from onyx.server.api_key.api import router as api_key_router
from onyx.server.auth_check import check_router_auth
from onyx.server.documents.cc_pair import router as cc_pair_router
@@ -308,6 +309,7 @@ def validate_no_vector_db_settings() -> None:
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
validate_no_vector_db_settings()
validate_cache_backend_settings()
validate_registry()
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:

View File

@@ -0,0 +1,22 @@
from typing import Any
import pytest
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
def test_init_subclass_raises_for_missing_attrs() -> None:
with pytest.raises(TypeError, match="must define class attributes"):
class IncompleteSpec(HookPointSpec):
hook_point = HookPoint.QUERY_PROCESSING
# missing display_name, description, etc.
@property
def input_schema(self) -> dict[str, Any]:
return {}
@property
def output_schema(self) -> dict[str, Any]:
return {}

View File

@@ -0,0 +1,86 @@
import pytest
from pydantic import ValidationError
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.models import HookCreateRequest
from onyx.hooks.models import HookUpdateRequest
def test_hook_update_request_rejects_empty() -> None:
# No fields supplied at all
with pytest.raises(ValidationError, match="At least one field must be provided"):
HookUpdateRequest()
def test_hook_update_request_rejects_null_name_when_only_field() -> None:
# Explicitly setting name=None is rejected as name cannot be cleared
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name=None)
def test_hook_update_request_accepts_single_field() -> None:
req = HookUpdateRequest(name="new name")
assert req.name == "new name"
def test_hook_update_request_accepts_partial_fields() -> None:
req = HookUpdateRequest(fail_strategy=HookFailStrategy.SOFT, timeout_seconds=10.0)
assert req.fail_strategy == HookFailStrategy.SOFT
assert req.timeout_seconds == 10.0
assert req.name is None
def test_hook_update_request_rejects_null_name() -> None:
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name=None, fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_empty_name() -> None:
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name="", fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_null_endpoint_url() -> None:
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
HookUpdateRequest(endpoint_url=None, fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_empty_endpoint_url() -> None:
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
HookUpdateRequest(endpoint_url="", fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_allows_null_api_key() -> None:
# api_key=null is valid — means "clear the api key"
req = HookUpdateRequest(api_key=None)
assert req.api_key is None
assert "api_key" in req.model_fields_set
def test_hook_update_request_rejects_whitespace_name() -> None:
with pytest.raises(ValidationError, match="name cannot be cleared"):
HookUpdateRequest(name=" ", fail_strategy=HookFailStrategy.SOFT)
def test_hook_update_request_rejects_whitespace_endpoint_url() -> None:
with pytest.raises(ValidationError, match="endpoint_url cannot be cleared"):
HookUpdateRequest(endpoint_url=" ", fail_strategy=HookFailStrategy.SOFT)
def test_hook_create_request_rejects_whitespace_name() -> None:
with pytest.raises(ValidationError, match="whitespace-only"):
HookCreateRequest(
name=" ",
hook_point=HookPoint.QUERY_PROCESSING,
endpoint_url="https://example.com/hook",
)
def test_hook_create_request_rejects_whitespace_endpoint_url() -> None:
with pytest.raises(ValidationError, match="whitespace-only"):
HookCreateRequest(
name="my hook",
hook_point=HookPoint.QUERY_PROCESSING,
endpoint_url=" ",
)

View File

@@ -0,0 +1,54 @@
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.query_processing import QueryProcessingSpec
def test_hook_point_is_query_processing() -> None:
assert QueryProcessingSpec().hook_point == HookPoint.QUERY_PROCESSING
def test_default_fail_strategy_is_hard() -> None:
assert QueryProcessingSpec().default_fail_strategy == HookFailStrategy.HARD
def test_default_timeout_seconds() -> None:
# User is actively waiting — 5s is the documented contract for this hook point
assert QueryProcessingSpec().default_timeout_seconds == 5.0
def test_input_schema_required_fields() -> None:
schema = QueryProcessingSpec().input_schema
assert schema["type"] == "object"
required = schema["required"]
assert "query" in required
assert "user_email" in required
def test_input_schema_query_is_string() -> None:
props = QueryProcessingSpec().input_schema["properties"]
assert props["query"]["type"] == "string"
def test_input_schema_user_email_is_nullable() -> None:
props = QueryProcessingSpec().input_schema["properties"]
assert "null" in props["user_email"]["type"]
def test_output_schema_query_is_required() -> None:
schema = QueryProcessingSpec().output_schema
assert "query" in schema["required"]
def test_output_schema_query_is_nullable() -> None:
# null means "reject the query"
props = QueryProcessingSpec().output_schema["properties"]
assert "null" in props["query"]["type"]
def test_output_schema_rejection_message_is_optional() -> None:
schema = QueryProcessingSpec().output_schema
assert "rejection_message" not in schema.get("required", [])
def test_input_schema_no_additional_properties() -> None:
assert QueryProcessingSpec().input_schema.get("additionalProperties") is False

View File

@@ -0,0 +1,47 @@
import pytest
from onyx.db.enums import HookPoint
from onyx.hooks import registry as registry_module
from onyx.hooks.registry import get_all_specs
from onyx.hooks.registry import get_hook_point_spec
from onyx.hooks.registry import validate_registry
def test_registry_covers_all_hook_points() -> None:
"""Every HookPoint enum member must have a registered spec."""
assert {s.hook_point for s in get_all_specs()} == set(
HookPoint
), f"Missing specs for: {set(HookPoint) - {s.hook_point for s in get_all_specs()}}"
def test_get_hook_point_spec_returns_correct_spec() -> None:
for hook_point in HookPoint:
spec = get_hook_point_spec(hook_point)
assert spec.hook_point == hook_point
def test_get_all_specs_returns_all() -> None:
specs = get_all_specs()
assert len(specs) == len(HookPoint)
assert {s.hook_point for s in specs} == set(HookPoint)
def test_get_hook_point_spec_raises_for_unregistered(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""get_hook_point_spec raises ValueError when a hook point has no spec."""
monkeypatch.setattr(registry_module, "_REGISTRY", {})
with pytest.raises(ValueError, match="No spec registered for hook point"):
get_hook_point_spec(HookPoint.QUERY_PROCESSING)
def test_validate_registry_passes() -> None:
validate_registry() # should not raise with the real registry
def test_validate_registry_raises_for_incomplete(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(registry_module, "_REGISTRY", {})
with pytest.raises(RuntimeError, match="Hook point\\(s\\) have no registered spec"):
validate_registry()

View File

@@ -44,13 +44,13 @@ class TestBuildVespaFilters:
assert result == f'({SOURCE_TYPE} contains "web") and '
def test_acl(self) -> None:
"""Test with acls — uses weightedSet operator for efficient matching."""
"""Test with acls."""
# Single ACL
filters = IndexFilters(access_control_list=["user1"])
result = build_vespa_filters(filters)
assert (
result
== f'!({HIDDEN}=true) and weightedSet(access_control_list, {{"user1":1}}) and '
== f'!({HIDDEN}=true) and (access_control_list contains "user1") and '
)
# Multiple ACL's
@@ -58,7 +58,7 @@ class TestBuildVespaFilters:
result = build_vespa_filters(filters)
assert (
result
== f'!({HIDDEN}=true) and weightedSet(access_control_list, {{"user2":1, "group2":1}}) and '
== f'!({HIDDEN}=true) and (access_control_list contains "user2" or access_control_list contains "group2") and '
)
def test_tenant_filter(self) -> None:
@@ -250,7 +250,7 @@ class TestBuildVespaFilters:
result = build_vespa_filters(filters)
expected = f"!({HIDDEN}=true) and "
expected += 'weightedSet(access_control_list, {"user1":1, "group1":1}) and '
expected += '(access_control_list contains "user1" or access_control_list contains "group1") and '
expected += f'({SOURCE_TYPE} contains "web") and '
expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
# Knowledge scope filters are OR'd together
@@ -290,38 +290,6 @@ class TestBuildVespaFilters:
expected = f'!({HIDDEN}=true) and (({DOCUMENT_SETS} contains "engineering") or ({DOCUMENT_ID} contains "{str(id1)}")) and '
assert expected == result
def test_acl_large_list_uses_weighted_set(self) -> None:
"""Verify that large ACL lists produce a weightedSet clause
instead of OR-chained contains — this is what prevents Vespa
HTTP 400 errors for users with thousands of permission groups."""
acl = [f"external_group:google_drive_{i}" for i in range(10_000)]
acl += ["user_email:user@example.com", "__PUBLIC__"]
filters = IndexFilters(access_control_list=acl)
result = build_vespa_filters(filters)
assert "weightedSet(access_control_list, {" in result
# Must NOT contain OR-chained contains clauses
assert "access_control_list contains" not in result
# All entries should be present
assert '"external_group:google_drive_0":1' in result
assert '"external_group:google_drive_9999":1' in result
assert '"user_email:user@example.com":1' in result
assert '"__PUBLIC__":1' in result
def test_acl_empty_strings_filtered(self) -> None:
"""Empty strings in the ACL list should be filtered out."""
filters = IndexFilters(access_control_list=["user1", "", "group1"])
result = build_vespa_filters(filters)
assert (
result
== f'!({HIDDEN}=true) and weightedSet(access_control_list, {{"user1":1, "group1":1}}) and '
)
# All empty
filters = IndexFilters(access_control_list=["", ""])
result = build_vespa_filters(filters)
assert result == f"!({HIDDEN}=true) and "
def test_empty_or_none_values(self) -> None:
"""Test with empty or None values in filter lists."""
# Empty strings in document set