Compare commits

..

4 Commits

Author SHA1 Message Date
Jamison Lahman
f0710af7a2 chore(fe): position memories modal at the top 2026-03-23 09:16:05 -07:00
Jamison Lahman
bd42c459d6 chore(fe): update memories dropdown padding (#9526) 2026-03-20 23:38:32 +00:00
Danelegend
aede532e63 fix(chat): Cache plaintext file results (#9511) 2026-03-20 23:21:12 +00:00
Evan Lohn
068ac543ad fix: deadlock in multitenant test (#9530) 2026-03-20 23:05:20 +00:00
13 changed files with 174 additions and 442 deletions

View File

@@ -30,6 +30,8 @@ from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.utils import plaintext_file_name_for_id
from onyx.file_store.utils import store_plaintext
from onyx.kg.models import KGException
from onyx.kg.setup.kg_default_entity_definitions import (
populate_missing_default_entity_types__commit,
@@ -289,6 +291,33 @@ def process_kg_commands(
raise KGException("KG setup done")
def _get_or_extract_plaintext(
file_id: str,
extract_fn: Callable[[], str],
) -> str:
"""Load cached plaintext for a file, or extract and store it.
Tries to read pre-stored plaintext from the file store. On a miss,
calls extract_fn to produce the text, then stores the result so
future calls skip the expensive extraction.
"""
file_store = get_default_file_store()
plaintext_key = plaintext_file_name_for_id(file_id)
# Try cached plaintext first.
try:
plaintext_io = file_store.read_file(plaintext_key, mode="b")
return plaintext_io.read().decode("utf-8")
except Exception:
logger.exception(f"Error when reading file, id={file_id}")
# Cache miss — extract and store.
content_text = extract_fn()
if content_text:
store_plaintext(file_id, content_text)
return content_text
@log_function_time(print_only=True)
def load_chat_file(
file_descriptor: FileDescriptor, db_session: Session
@@ -303,12 +332,23 @@ def load_chat_file(
file_type = ChatFileType(file_descriptor["type"])
if file_type.is_text_file():
try:
content_text = extract_file_text(
file_id = file_descriptor["id"]
def _extract() -> str:
return extract_file_text(
file=file_io,
file_name=file_descriptor.get("name") or "",
break_on_unprocessable=False,
)
# Use the user_file_id as cache key when available (matches what
# the celery indexing worker stores), otherwise fall back to the
# file store id (covers code-interpreter-generated files, etc.).
user_file_id_str = file_descriptor.get("user_file_id")
cache_key = user_file_id_str or file_id
try:
content_text = _get_or_extract_plaintext(cache_key, _extract)
except Exception as e:
logger.warning(
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"

View File

@@ -59,7 +59,6 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -69,19 +68,11 @@ from onyx.db.models import UserFile
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingPayload
from onyx.hooks.points.query_processing import QueryProcessingResponse
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
@@ -433,32 +424,6 @@ def determine_search_params(
)
def _apply_query_processing_hook(
hook_result: BaseModel | HookSkipped | HookSoftFailed,
message_text: str,
) -> str:
"""Apply the Query Processing hook result to the message text.
Returns the (possibly rewritten) message text, or raises OnyxError with
QUERY_REJECTED if the hook signals rejection (query is null or empty).
HookSkipped and HookSoftFailed are pass-throughs — the original text is
returned unchanged.
"""
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
return message_text
if not isinstance(hook_result, QueryProcessingResponse):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"Expected QueryProcessingResponse from hook, got {type(hook_result).__name__}",
)
if not hook_result.query:
raise OnyxError(
OnyxErrorCode.QUERY_REJECTED,
hook_result.rejection_message or "Your query was rejected.",
)
return hook_result.query
def handle_stream_message_objects(
new_msg_req: SendMessageRequest,
user: User,
@@ -519,7 +484,6 @@ def handle_stream_message_objects(
persona = chat_session.persona
message_text = new_msg_req.message
user_identity = LLMUserIdentity(
user_id=llm_user_identifier, session_id=str(chat_session.id)
)
@@ -611,22 +575,6 @@ def handle_stream_message_objects(
if parent_message.message_type == MessageType.USER:
user_message = parent_message
else:
# New message — run the Query Processing hook before saving to DB.
# Skipped on regeneration: the message already exists and was accepted previously.
hook_result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=QueryProcessingPayload(
query=message_text,
# Pass None for anonymous users or authenticated users without an email
# (e.g. some SSO flows). QueryProcessingPayload.user_email is str | None,
# so None is accepted and serialised as null in both cases.
user_email=None if user.is_anonymous else user.email,
chat_session_id=str(chat_session.id),
).model_dump(),
)
message_text = _apply_query_processing_hook(hook_result, message_text)
user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=parent_message,
@@ -966,17 +914,6 @@ def handle_stream_message_objects(
state_container=state_container,
)
except OnyxError as e:
if e.error_code is not OnyxErrorCode.QUERY_REJECTED:
log_onyx_error(e)
yield StreamingError(
error=e.detail,
error_code=e.error_code.code,
is_retryable=e.status_code >= 500,
)
db_session.rollback()
return
except ValueError as e:
logger.exception("Failed to process chat message.")

View File

@@ -44,7 +44,6 @@ class OnyxErrorCode(Enum):
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
INVALID_INPUT = ("INVALID_INPUT", 400)
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
QUERY_REJECTED = ("QUERY_REJECTED", 400)
# ------------------------------------------------------------------
# Not Found (404)

View File

@@ -23,45 +23,55 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
return f"plaintext_{user_file_id}"
def plaintext_file_name_for_id(file_id: str) -> str:
"""Generate a consistent file name for storing plaintext content of a file."""
return f"plaintext_{file_id}"
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
def store_plaintext(file_id: str, plaintext_content: str) -> bool:
"""
Store plaintext content for a user file in the file store.
Store plaintext content for a file in the file store.
Args:
user_file_id: The ID of the user file
file_id: The ID of the file (user_file or artifact_file)
plaintext_content: The plaintext content to store
Returns:
bool: True if storage was successful, False otherwise
"""
# Skip empty content
if not plaintext_content:
return False
# Get plaintext file name
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
plaintext_file_name = plaintext_file_name_for_id(file_id)
try:
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))
file_store.save_file(
content=file_content,
display_name=f"Plaintext for user file {user_file_id}",
display_name=f"Plaintext for {file_id}",
file_origin=FileOrigin.PLAINTEXT_CACHE,
file_type="text/plain",
file_id=plaintext_file_name,
)
return True
except Exception as e:
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
logger.warning(f"Failed to store plaintext for {file_id}: {e}")
return False
# --- Convenience wrappers for callers that use user-file UUIDs ---
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
return plaintext_file_name_for_id(str(user_file_id))
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
"""Store plaintext content for a user file (delegates to :func:`store_plaintext`)."""
return store_plaintext(str(user_file_id), plaintext_content)
def load_chat_file_by_id(file_id: str) -> InMemoryChatFile:
"""Load a file directly from the file store using its file_record ID.

View File

@@ -14,7 +14,7 @@ Usage (Celery tasks and FastAPI handlers):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (spec.response_model)
# result is the response payload dict from the customer's endpoint
...
is_reachable update policy
@@ -56,7 +56,6 @@ from typing import Any
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -68,7 +67,6 @@ from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.registry import get_hook_point_spec
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
@@ -270,21 +268,22 @@ def _persist_result(
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
) -> BaseModel | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously."""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
hook_point = hook.hook_point # extract before HTTP call per design intent
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
@@ -301,37 +300,13 @@ def _execute_hook_inner(
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
with httpx.Client(timeout=timeout) as client:
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against the spec's response_model.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: BaseModel | None = None
if outcome.is_success and outcome.response_payload is not None:
spec = get_hook_point_spec(hook_point)
try:
validated_model = spec.response_model.model_validate(
outcome.response_payload
)
except ValidationError as e:
msg = f"Hook response failed validation against {spec.response_model.__name__}: {e}"
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
@@ -348,41 +323,8 @@ def _execute_hook_inner(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
if outcome.response_payload is None:
raise ValueError(
f"validated_model is None for successful hook call (hook_id={hook_id})"
f"response_payload is None for successful hook call (hook_id={hook_id})"
)
return validated_model
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
) -> BaseModel | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
try:
return _execute_hook_inner(hook, payload)
except OnyxError:
# OnyxError(HOOK_EXECUTION_FAILED) is only raised under HARD strategy in
# _execute_hook_inner, so re-raise unconditionally — never silently swallow
# an OnyxError into HookSoftFailed, even under SOFT.
raise
except Exception:
if hook.fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook.id}"
)
return HookSoftFailed()
raise
return outcome.response_payload

View File

@@ -51,12 +51,13 @@ class HookPointSpec:
output_schema: ClassVar[dict[str, Any]]
def __init_subclass__(cls, **kwargs: object) -> None:
"""Enforce that every subclass declares all required class attributes.
"""Enforce that every concrete subclass declares all required class attributes.
Called automatically by Python whenever a class inherits from HookPointSpec.
Raises TypeError at import time if any required attribute is missing or if
payload_model / response_model are not Pydantic BaseModel subclasses.
input_schema and output_schema are derived automatically from the models.
Abstract subclasses (those still carrying unimplemented abstract methods) are
skipped — they are intermediate base classes and may not yet define everything.
Only fully concrete subclasses are validated, ensuring a clear TypeError at
import time rather than a confusing AttributeError at runtime.
"""
super().__init_subclass__(**kwargs)
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]

View File

@@ -15,7 +15,7 @@ class QueryProcessingPayload(BaseModel):
description="Email of the user submitting the query, or null if unauthenticated."
)
chat_session_id: str = Field(
description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
)

View File

@@ -14,6 +14,7 @@ from __future__ import annotations
import os
import subprocess
import sys
import time
import uuid
from collections.abc import Generator
@@ -28,6 +29,9 @@ _BACKEND_DIR = os.path.normpath(
os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
)
_DROP_SCHEMA_MAX_RETRIES = 3
_DROP_SCHEMA_RETRY_DELAY_SEC = 2
# ---------------------------------------------------------------------------
# Helpers
@@ -50,6 +54,39 @@ def _run_script(
)
def _force_drop_schema(engine: Engine, schema: str) -> None:
"""Terminate backends using *schema* then drop it, retrying on deadlock.
Background Celery workers may discover test schemas (they match the
``tenant_`` prefix) and hold locks on tables inside them. A bare
``DROP SCHEMA … CASCADE`` can deadlock with those workers, so we
first kill their connections and retry if we still hit a deadlock.
"""
for attempt in range(_DROP_SCHEMA_MAX_RETRIES):
try:
with engine.connect() as conn:
conn.execute(
text(
"""
SELECT pg_terminate_backend(l.pid)
FROM pg_locks l
JOIN pg_class c ON c.oid = l.relation
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = :schema
AND l.pid != pg_backend_pid()
"""
),
{"schema": schema},
)
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
return
except Exception:
if attempt == _DROP_SCHEMA_MAX_RETRIES - 1:
raise
time.sleep(_DROP_SCHEMA_RETRY_DELAY_SEC)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@@ -104,9 +141,7 @@ def tenant_schema_at_head(
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
@pytest.fixture
@@ -123,9 +158,7 @@ def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
@pytest.fixture
@@ -150,9 +183,7 @@ def tenant_schema_bad_rev(engine: Engine) -> Generator[str, None, None]:
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
# ---------------------------------------------------------------------------

View File

@@ -1,12 +1,4 @@
import pytest
from onyx.chat.process_message import _apply_query_processing_hook
from onyx.chat.process_message import remove_answer_citations
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
def test_remove_answer_citations_strips_http_markdown_citation() -> None:
@@ -40,83 +32,3 @@ def test_remove_answer_citations_preserves_non_citation_markdown_links() -> None
remove_answer_citations(answer)
== "See [reference](https://example.com/Function_(mathematics)) for context."
)
# ---------------------------------------------------------------------------
# Query Processing hook response handling (_apply_query_processing_hook)
# ---------------------------------------------------------------------------
def test_wrong_model_type_raises_internal_error() -> None:
"""If the executor ever returns an unexpected BaseModel type, raise INTERNAL_ERROR
rather than an AssertionError or AttributeError."""
from pydantic import BaseModel as PydanticBaseModel
class _OtherModel(PydanticBaseModel):
pass
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(_OtherModel(), "original query")
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
def test_hook_skipped_leaves_message_text_unchanged() -> None:
result = _apply_query_processing_hook(HookSkipped(), "original query")
assert result == "original query"
def test_hook_soft_failed_leaves_message_text_unchanged() -> None:
result = _apply_query_processing_hook(HookSoftFailed(), "original query")
assert result == "original query"
def test_null_query_raises_query_rejected() -> None:
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=None), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_empty_string_query_raises_query_rejected() -> None:
"""Empty string is falsy — must be treated as rejection, same as None."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=""), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_absent_query_field_raises_query_rejected() -> None:
"""query defaults to None when not provided."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(QueryProcessingResponse(), "original query")
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_rejection_message_surfaced_in_error_when_provided() -> None:
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(
query=None, rejection_message="Queries about X are not allowed."
),
"original query",
)
assert "Queries about X are not allowed." in str(exc_info.value)
def test_fallback_rejection_message_when_none() -> None:
"""No rejection_message → generic fallback used in OnyxError detail."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=None, rejection_message=None),
"original query",
)
assert "Your query was rejected." in str(exc_info.value)
def test_nonempty_query_rewrites_message_text() -> None:
result = _apply_query_processing_hook(
QueryProcessingResponse(query="rewritten query"), "original query"
)
assert result == "rewritten query"

View File

@@ -7,7 +7,6 @@ from unittest.mock import patch
import httpx
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
@@ -16,15 +15,13 @@ from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
# A valid QueryProcessingResponse payload — used by success-path tests.
_RESPONSE_PAYLOAD: dict[str, Any] = {"query": "better test"}
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
def _make_hook(
@@ -36,7 +33,6 @@ def _make_hook(
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
hook_id: int = 1,
is_reachable: bool | None = None,
hook_point: HookPoint = HookPoint.QUERY_PROCESSING,
) -> MagicMock:
hook = MagicMock()
hook.is_active = is_active
@@ -46,7 +42,6 @@ def _make_hook(
hook.id = hook_id
hook.fail_strategy = fail_strategy
hook.is_reachable = is_reachable
hook.hook_point = hook_point
return hook
@@ -157,9 +152,7 @@ def test_early_exit_returns_skipped_with_no_db_writes(
# ---------------------------------------------------------------------------
def test_success_returns_validated_model_and_sets_reachable(
db_session: MagicMock,
) -> None:
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
hook = _make_hook()
with (
@@ -180,8 +173,7 @@ def test_success_returns_validated_model_and_sets_reachable(
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
assert result == _RESPONSE_PAYLOAD
_, update_kwargs = mock_update.call_args
assert update_kwargs["is_reachable"] is True
mock_log.assert_not_called()
@@ -210,8 +202,7 @@ def test_success_skips_reachable_write_when_already_true(db_session: MagicMock)
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
assert result == _RESPONSE_PAYLOAD
mock_update.assert_not_called()
@@ -466,16 +457,16 @@ def test_authorization_header(
@pytest.mark.parametrize(
"http_exception,expect_onyx_error",
"http_exception,expected_result",
[
pytest.param(None, False, id="success_path"),
pytest.param(httpx.ConnectError("refused"), True, id="hard_fail_path"),
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
],
)
def test_persist_session_failure_is_swallowed(
db_session: MagicMock,
http_exception: Exception | None,
expect_onyx_error: bool,
expected_result: Any,
) -> None:
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
@@ -498,7 +489,7 @@ def test_persist_session_failure_is_swallowed(
side_effect=http_exception,
)
if expect_onyx_error:
if expected_result is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
@@ -512,162 +503,7 @@ def test_persist_session_failure_is_swallowed(
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
# ---------------------------------------------------------------------------
# Response model validation
# ---------------------------------------------------------------------------
class _StrictResponse(BaseModel):
"""Strict model used to reliably trigger a ValidationError in tests."""
required_field: str # no default → missing key raises ValidationError
def _make_strict_spec() -> MagicMock:
spec = MagicMock()
spec.response_model = _StrictResponse
return spec
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(
HookFailStrategy.SOFT, HookSoftFailed, id="validation_failure_soft"
),
pytest.param(HookFailStrategy.HARD, OnyxError, id="validation_failure_hard"),
],
)
def test_response_validation_failure_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""A response that fails response_model validation is treated like any other
hook failure: logged, is_reachable left unchanged, fail_strategy respected."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch(
"onyx.hooks.executor.get_hook_point_spec",
return_value=_make_strict_spec(),
),
patch("httpx.Client") as mock_client_cls,
):
# Response payload is missing required_field → ValidationError
_setup_client(mock_client_cls, response=_make_response(json_return={}))
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
# is_reachable must not be updated — server responded correctly
mock_update.assert_not_called()
# failure must be logged
mock_log.assert_called_once()
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "validation" in (log_kwargs["error_message"] or "").lower()
# ---------------------------------------------------------------------------
# Outer soft-fail guard in execute_hook
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(HookFailStrategy.SOFT, HookSoftFailed, id="unexpected_exc_soft"),
pytest.param(HookFailStrategy.HARD, ValueError, id="unexpected_exc_hard"),
],
)
def test_unexpected_exception_in_inner_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""An unexpected exception raised by _execute_hook_inner (not an OnyxError from
HARD fail — e.g. a bug or an assertion error) must be swallowed and return
HookSoftFailed for SOFT strategy, or re-raised for HARD strategy."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor._execute_hook_inner",
side_effect=ValueError("unexpected bug"),
),
):
if expected_type is HookSoftFailed:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
else:
with pytest.raises(ValueError, match="unexpected bug"):
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
def test_onyx_error_from_inner_is_never_swallowed_under_soft(
db_session: MagicMock,
) -> None:
"""OnyxError raised by _execute_hook_inner must be re-raised even under SOFT
strategy — the except OnyxError: raise guard must not let it get swallowed."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor._execute_hook_inner",
side_effect=OnyxError(OnyxErrorCode.HOOK_EXECUTION_FAILED, "hard fail"),
),
):
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
assert result == expected_result
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:

View File

@@ -56,8 +56,14 @@ function MemoryTagWithTooltip({
side="bottom"
className="bg-background-neutral-00 text-text-01 shadow-md max-w-[17.5rem] p-1"
tooltip={
<Section flexDirection="column" gap={0.25} height="auto">
<div className="p-1 w-full">
<Section
flexDirection="column"
alignItems="start"
padding={0.25}
gap={0.25}
height="auto"
>
<div className="p-1">
<Text as="p" secondaryBody text03>
{memoryText}
</Text>
@@ -66,6 +72,7 @@ function MemoryTagWithTooltip({
icon={SvgAddLines}
title={operationLabel}
sizePreset="secondary"
paddingVariant="sm"
variant="body"
prominence="muted"
rightChildren={

View File

@@ -120,6 +120,10 @@ export interface ModalContentProps
> {
width?: keyof typeof widthClasses;
height?: keyof typeof heightClasses;
/** Vertical placement of the modal. `"center"` (default) centers in the
* viewport/container. `"top"` pins the modal near the top of the viewport,
* matching the position used by CommandMenu. */
position?: "center" | "top";
preventAccidentalClose?: boolean;
skipOverlay?: boolean;
background?: "default" | "gray";
@@ -136,6 +140,7 @@ const ModalContent = React.forwardRef<
children,
width = "md",
height = "fit",
position = "center",
preventAccidentalClose = true,
skipOverlay = false,
background = "default",
@@ -267,27 +272,39 @@ const ModalContent = React.forwardRef<
const { centerX, centerY, hasContainerCenter } = useContainerCenter();
const isTop = position === "top";
const animationClasses = cn(
"data-[state=open]:fade-in-0 data-[state=closed]:fade-out-0",
"data-[state=open]:zoom-in-95 data-[state=closed]:zoom-out-95",
"data-[state=open]:slide-in-from-top-1/2 data-[state=closed]:slide-out-to-top-1/2",
!isTop &&
"data-[state=open]:slide-in-from-top-1/2 data-[state=closed]:slide-out-to-top-1/2",
"duration-200"
);
const containerStyle: React.CSSProperties | undefined = hasContainerCenter
? ({
left: centerX,
top: centerY,
"--tw-enter-translate-x": "-50%",
"--tw-exit-translate-x": "-50%",
"--tw-enter-translate-y": "-50%",
"--tw-exit-translate-y": "-50%",
} as React.CSSProperties)
: undefined;
const containerStyle: React.CSSProperties | undefined =
hasContainerCenter && !isTop
? ({
left: centerX,
top: centerY,
"--tw-enter-translate-x": "-50%",
"--tw-exit-translate-x": "-50%",
"--tw-enter-translate-y": "-50%",
"--tw-exit-translate-y": "-50%",
} as React.CSSProperties)
: hasContainerCenter && isTop
? ({
left: centerX,
"--tw-enter-translate-x": "-50%",
"--tw-exit-translate-x": "-50%",
} as React.CSSProperties)
: undefined;
const positionClasses = cn(
"fixed -translate-x-1/2 -translate-y-1/2",
!hasContainerCenter && "left-1/2 top-1/2"
"fixed -translate-x-1/2",
isTop
? cn("top-[72px]", !hasContainerCenter && "left-1/2")
: cn("-translate-y-1/2", !hasContainerCenter && "left-1/2 top-1/2")
);
const dialogEventHandlers = {

View File

@@ -228,7 +228,7 @@ export default function MemoriesModal({
return (
<Modal open onOpenChange={(open) => !open && close?.()}>
<Modal.Content width="sm" height="lg">
<Modal.Content width="sm" height="lg" position="top">
<Modal.Header
icon={SvgAddLines}
title="Memory"