Compare commits

...

2 Commits

Author SHA1 Message Date
Bo-Onyx
087aa945f5 address comments 2026-03-22 11:51:30 -07:00
Bo-Onyx
f77d5d2d01 feat(hook): integrate query processing hook point 2026-03-20 17:37:01 -07:00
7 changed files with 403 additions and 30 deletions

View File

@@ -59,6 +59,7 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -68,11 +69,19 @@ from onyx.db.models import UserFile
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingPayload
from onyx.hooks.points.query_processing import QueryProcessingResponse
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
@@ -424,6 +433,32 @@ def determine_search_params(
)
def _apply_query_processing_hook(
hook_result: BaseModel | HookSkipped | HookSoftFailed,
message_text: str,
) -> str:
"""Apply the Query Processing hook result to the message text.
Returns the (possibly rewritten) message text, or raises OnyxError with
QUERY_REJECTED if the hook signals rejection (query is null or empty).
HookSkipped and HookSoftFailed are pass-throughs — the original text is
returned unchanged.
"""
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
return message_text
if not isinstance(hook_result, QueryProcessingResponse):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"Expected QueryProcessingResponse from hook, got {type(hook_result).__name__}",
)
if not hook_result.query:
raise OnyxError(
OnyxErrorCode.QUERY_REJECTED,
hook_result.rejection_message or "Your query was rejected.",
)
return hook_result.query
def handle_stream_message_objects(
new_msg_req: SendMessageRequest,
user: User,
@@ -484,6 +519,7 @@ def handle_stream_message_objects(
persona = chat_session.persona
message_text = new_msg_req.message
user_identity = LLMUserIdentity(
user_id=llm_user_identifier, session_id=str(chat_session.id)
)
@@ -575,6 +611,22 @@ def handle_stream_message_objects(
if parent_message.message_type == MessageType.USER:
user_message = parent_message
else:
# New message — run the Query Processing hook before saving to DB.
# Skipped on regeneration: the message already exists and was accepted previously.
hook_result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=QueryProcessingPayload(
query=message_text,
# Pass None for anonymous users or authenticated users without an email
# (e.g. some SSO flows). QueryProcessingPayload.user_email is str | None,
# so None is accepted and serialised as null in both cases.
user_email=None if user.is_anonymous else user.email,
chat_session_id=str(chat_session.id),
).model_dump(),
)
message_text = _apply_query_processing_hook(hook_result, message_text)
user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=parent_message,
@@ -914,6 +966,17 @@ def handle_stream_message_objects(
state_container=state_container,
)
except OnyxError as e:
if e.error_code is not OnyxErrorCode.QUERY_REJECTED:
log_onyx_error(e)
yield StreamingError(
error=e.detail,
error_code=e.error_code.code,
is_retryable=e.status_code >= 500,
)
db_session.rollback()
return
except ValueError as e:
logger.exception("Failed to process chat message.")

View File

@@ -44,6 +44,7 @@ class OnyxErrorCode(Enum):
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
INVALID_INPUT = ("INVALID_INPUT", 400)
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
QUERY_REJECTED = ("QUERY_REJECTED", 400)
# ------------------------------------------------------------------
# Not Found (404)

View File

@@ -14,7 +14,7 @@ Usage (Celery tasks and FastAPI handlers):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is the response payload dict from the customer's endpoint
# result is a validated Pydantic model instance (spec.response_model)
...
is_reachable update policy
@@ -56,6 +56,7 @@ from typing import Any
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -67,6 +68,7 @@ from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.registry import get_hook_point_spec
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
@@ -268,22 +270,21 @@ def _persist_result(
# ---------------------------------------------------------------------------
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously."""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
) -> BaseModel | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
hook_point = hook.hook_point # extract before HTTP call per design intent
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
@@ -300,13 +301,37 @@ def execute_hook(
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(timeout=timeout) as client:
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against the spec's response_model.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: BaseModel | None = None
if outcome.is_success and outcome.response_payload is not None:
spec = get_hook_point_spec(hook_point)
try:
validated_model = spec.response_model.model_validate(
outcome.response_payload
)
except ValidationError as e:
msg = f"Hook response failed validation against {spec.response_model.__name__}: {e}"
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
@@ -323,8 +348,41 @@ def execute_hook(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if outcome.response_payload is None:
if validated_model is None:
raise ValueError(
f"response_payload is None for successful hook call (hook_id={hook_id})"
f"validated_model is None for successful hook call (hook_id={hook_id})"
)
return outcome.response_payload
return validated_model
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
) -> BaseModel | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
try:
return _execute_hook_inner(hook, payload)
except OnyxError:
# OnyxError(HOOK_EXECUTION_FAILED) is only raised under HARD strategy in
# _execute_hook_inner, so re-raise unconditionally — never silently swallow
# an OnyxError into HookSoftFailed, even under SOFT.
raise
except Exception:
if hook.fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook.id}"
)
return HookSoftFailed()
raise

View File

@@ -51,13 +51,12 @@ class HookPointSpec:
output_schema: ClassVar[dict[str, Any]]
def __init_subclass__(cls, **kwargs: object) -> None:
"""Enforce that every concrete subclass declares all required class attributes.
"""Enforce that every subclass declares all required class attributes.
Called automatically by Python whenever a class inherits from HookPointSpec.
Abstract subclasses (those still carrying unimplemented abstract methods) are
skipped — they are intermediate base classes and may not yet define everything.
Only fully concrete subclasses are validated, ensuring a clear TypeError at
import time rather than a confusing AttributeError at runtime.
Raises TypeError at import time if any required attribute is missing or if
payload_model / response_model are not Pydantic BaseModel subclasses.
input_schema and output_schema are derived automatically from the models.
"""
super().__init_subclass__(**kwargs)
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]

View File

@@ -15,7 +15,7 @@ class QueryProcessingPayload(BaseModel):
description="Email of the user submitting the query, or null if unauthenticated."
)
chat_session_id: str = Field(
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
)

View File

@@ -1,4 +1,12 @@
import pytest
from onyx.chat.process_message import _apply_query_processing_hook
from onyx.chat.process_message import remove_answer_citations
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
def test_remove_answer_citations_strips_http_markdown_citation() -> None:
@@ -32,3 +40,83 @@ def test_remove_answer_citations_preserves_non_citation_markdown_links() -> None
remove_answer_citations(answer)
== "See [reference](https://example.com/Function_(mathematics)) for context."
)
# ---------------------------------------------------------------------------
# Query Processing hook response handling (_apply_query_processing_hook)
# ---------------------------------------------------------------------------
def test_wrong_model_type_raises_internal_error() -> None:
"""If the executor ever returns an unexpected BaseModel type, raise INTERNAL_ERROR
rather than an AssertionError or AttributeError."""
from pydantic import BaseModel as PydanticBaseModel
class _OtherModel(PydanticBaseModel):
pass
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(_OtherModel(), "original query")
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
def test_hook_skipped_leaves_message_text_unchanged() -> None:
result = _apply_query_processing_hook(HookSkipped(), "original query")
assert result == "original query"
def test_hook_soft_failed_leaves_message_text_unchanged() -> None:
result = _apply_query_processing_hook(HookSoftFailed(), "original query")
assert result == "original query"
def test_null_query_raises_query_rejected() -> None:
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=None), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_empty_string_query_raises_query_rejected() -> None:
"""Empty string is falsy — must be treated as rejection, same as None."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=""), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_absent_query_field_raises_query_rejected() -> None:
"""query defaults to None when not provided."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(QueryProcessingResponse(), "original query")
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_rejection_message_surfaced_in_error_when_provided() -> None:
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(
query=None, rejection_message="Queries about X are not allowed."
),
"original query",
)
assert "Queries about X are not allowed." in str(exc_info.value)
def test_fallback_rejection_message_when_none() -> None:
"""No rejection_message → generic fallback used in OnyxError detail."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=None, rejection_message=None),
"original query",
)
assert "Your query was rejected." in str(exc_info.value)
def test_nonempty_query_rewrites_message_text() -> None:
result = _apply_query_processing_hook(
QueryProcessingResponse(query="rewritten query"), "original query"
)
assert result == "rewritten query"

View File

@@ -7,6 +7,7 @@ from unittest.mock import patch
import httpx
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
@@ -15,13 +16,15 @@ from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
# A valid QueryProcessingResponse payload — used by success-path tests.
_RESPONSE_PAYLOAD: dict[str, Any] = {"query": "better test"}
def _make_hook(
@@ -33,6 +36,7 @@ def _make_hook(
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
hook_id: int = 1,
is_reachable: bool | None = None,
hook_point: HookPoint = HookPoint.QUERY_PROCESSING,
) -> MagicMock:
hook = MagicMock()
hook.is_active = is_active
@@ -42,6 +46,7 @@ def _make_hook(
hook.id = hook_id
hook.fail_strategy = fail_strategy
hook.is_reachable = is_reachable
hook.hook_point = hook_point
return hook
@@ -152,7 +157,9 @@ def test_early_exit_returns_skipped_with_no_db_writes(
# ---------------------------------------------------------------------------
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
def test_success_returns_validated_model_and_sets_reachable(
db_session: MagicMock,
) -> None:
hook = _make_hook()
with (
@@ -173,7 +180,8 @@ def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> No
payload=_PAYLOAD,
)
assert result == _RESPONSE_PAYLOAD
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
_, update_kwargs = mock_update.call_args
assert update_kwargs["is_reachable"] is True
mock_log.assert_not_called()
@@ -202,7 +210,8 @@ def test_success_skips_reachable_write_when_already_true(db_session: MagicMock)
payload=_PAYLOAD,
)
assert result == _RESPONSE_PAYLOAD
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
mock_update.assert_not_called()
@@ -457,16 +466,16 @@ def test_authorization_header(
@pytest.mark.parametrize(
"http_exception,expected_result",
"http_exception,expect_onyx_error",
[
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
pytest.param(None, False, id="success_path"),
pytest.param(httpx.ConnectError("refused"), True, id="hard_fail_path"),
],
)
def test_persist_session_failure_is_swallowed(
db_session: MagicMock,
http_exception: Exception | None,
expected_result: Any,
expect_onyx_error: bool,
) -> None:
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
@@ -489,7 +498,7 @@ def test_persist_session_failure_is_swallowed(
side_effect=http_exception,
)
if expected_result is OnyxError:
if expect_onyx_error:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
@@ -503,7 +512,162 @@ def test_persist_session_failure_is_swallowed(
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert result == expected_result
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
# ---------------------------------------------------------------------------
# Response model validation
# ---------------------------------------------------------------------------
class _StrictResponse(BaseModel):
"""Strict model used to reliably trigger a ValidationError in tests."""
required_field: str # no default → missing key raises ValidationError
def _make_strict_spec() -> MagicMock:
spec = MagicMock()
spec.response_model = _StrictResponse
return spec
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(
HookFailStrategy.SOFT, HookSoftFailed, id="validation_failure_soft"
),
pytest.param(HookFailStrategy.HARD, OnyxError, id="validation_failure_hard"),
],
)
def test_response_validation_failure_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""A response that fails response_model validation is treated like any other
hook failure: logged, is_reachable left unchanged, fail_strategy respected."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch(
"onyx.hooks.executor.get_hook_point_spec",
return_value=_make_strict_spec(),
),
patch("httpx.Client") as mock_client_cls,
):
# Response payload is missing required_field → ValidationError
_setup_client(mock_client_cls, response=_make_response(json_return={}))
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
# is_reachable must not be updated — server responded correctly
mock_update.assert_not_called()
# failure must be logged
mock_log.assert_called_once()
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "validation" in (log_kwargs["error_message"] or "").lower()
# ---------------------------------------------------------------------------
# Outer soft-fail guard in execute_hook
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(HookFailStrategy.SOFT, HookSoftFailed, id="unexpected_exc_soft"),
pytest.param(HookFailStrategy.HARD, ValueError, id="unexpected_exc_hard"),
],
)
def test_unexpected_exception_in_inner_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""An unexpected exception raised by _execute_hook_inner (not an OnyxError from
HARD fail — e.g. a bug or an assertion error) must be swallowed and return
HookSoftFailed for SOFT strategy, or re-raised for HARD strategy."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor._execute_hook_inner",
side_effect=ValueError("unexpected bug"),
),
):
if expected_type is HookSoftFailed:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
else:
with pytest.raises(ValueError, match="unexpected bug"):
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
def test_onyx_error_from_inner_is_never_swallowed_under_soft(
db_session: MagicMock,
) -> None:
"""OnyxError raised by _execute_hook_inner must be re-raised even under SOFT
strategy — the except OnyxError: raise guard must not let it get swallowed."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor._execute_hook_inner",
side_effect=OnyxError(OnyxErrorCode.HOOK_EXECUTION_FAILED, "hard fail"),
),
):
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None: