Compare commits

..

9 Commits

Author SHA1 Message Date
Evan Lohn
5186356a26 fix: windows install improvements (#9542) 2026-03-22 20:59:35 +00:00
Jamison Lahman
7b826e2a4e chore(fe): auto-focus clicked memory, improve action hover style (#9532) 2026-03-22 19:16:10 +00:00
Justin Tahara
c175dc8f6a fix(mt): Tenant Provisioning Fixes (#9541) 2026-03-22 17:50:00 +00:00
Raunak Bhagat
aa11813cc0 feat: UserAvatar (#9527) 2026-03-21 02:05:00 +00:00
Evan Lohn
6235f49b49 fix: csv test with newlines (#9534) 2026-03-21 01:30:11 +00:00
Evan Lohn
fd6a110794 feat: installer invocable from other bash script (#9531) 2026-03-21 01:18:20 +00:00
Jamison Lahman
bd42c459d6 chore(fe): update memories dropdown padding (#9526) 2026-03-20 23:38:32 +00:00
Danelegend
aede532e63 fix(chat): Cache plaintext file results (#9511) 2026-03-20 23:21:12 +00:00
Evan Lohn
068ac543ad fix: deadlock in multitenant test (#9530) 2026-03-20 23:05:20 +00:00
26 changed files with 413 additions and 580 deletions

View File

@@ -25,9 +25,6 @@ from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
# Default number of pre-provisioned tenants to maintain
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
# Soft time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
@@ -58,7 +55,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_check: RedisLock = r.lock(
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
)
# These tasks should never overlap
@@ -74,9 +71,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
num_available_tenants = db_session.query(AvailableTenant).count()
# Get the target number of available tenants
num_minimum_available_tenants = getattr(
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
)
num_minimum_available_tenants = TARGET_AVAILABLE_TENANTS
# Calculate how many new tenants we need to provision
if num_available_tenants < num_minimum_available_tenants:
@@ -98,7 +93,12 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
task_logger.exception("Error in check_available_tenants task")
finally:
lock_check.release()
try:
lock_check.release()
except Exception:
task_logger.warning(
"Could not release check lock (likely expired), continuing"
)
def pre_provision_tenant() -> None:
@@ -113,7 +113,7 @@ def pre_provision_tenant() -> None:
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_provision: RedisLock = r.lock(
OnyxRedisLocks.CLOUD_PRE_PROVISION_TENANT_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
)
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
@@ -185,4 +185,9 @@ def pre_provision_tenant() -> None:
except Exception:
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
finally:
lock_provision.release()
try:
lock_provision.release()
except Exception:
task_logger.warning(
"Could not release provision lock (likely expired), continuing"
)

View File

@@ -30,6 +30,8 @@ from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.utils import plaintext_file_name_for_id
from onyx.file_store.utils import store_plaintext
from onyx.kg.models import KGException
from onyx.kg.setup.kg_default_entity_definitions import (
populate_missing_default_entity_types__commit,
@@ -289,6 +291,33 @@ def process_kg_commands(
raise KGException("KG setup done")
def _get_or_extract_plaintext(
file_id: str,
extract_fn: Callable[[], str],
) -> str:
"""Load cached plaintext for a file, or extract and store it.
Tries to read pre-stored plaintext from the file store. On a miss,
calls extract_fn to produce the text, then stores the result so
future calls skip the expensive extraction.
"""
file_store = get_default_file_store()
plaintext_key = plaintext_file_name_for_id(file_id)
# Try cached plaintext first.
try:
plaintext_io = file_store.read_file(plaintext_key, mode="b")
return plaintext_io.read().decode("utf-8")
except Exception:
logger.exception(f"Error when reading file, id={file_id}")
# Cache miss — extract and store.
content_text = extract_fn()
if content_text:
store_plaintext(file_id, content_text)
return content_text
@log_function_time(print_only=True)
def load_chat_file(
file_descriptor: FileDescriptor, db_session: Session
@@ -303,12 +332,23 @@ def load_chat_file(
file_type = ChatFileType(file_descriptor["type"])
if file_type.is_text_file():
try:
content_text = extract_file_text(
file_id = file_descriptor["id"]
def _extract() -> str:
return extract_file_text(
file=file_io,
file_name=file_descriptor.get("name") or "",
break_on_unprocessable=False,
)
# Use the user_file_id as cache key when available (matches what
# the celery indexing worker stores), otherwise fall back to the
# file store id (covers code-interpreter-generated files, etc.).
user_file_id_str = file_descriptor.get("user_file_id")
cache_key = user_file_id_str or file_id
try:
content_text = _get_or_extract_plaintext(cache_key, _extract)
except Exception as e:
logger.warning(
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"

View File

@@ -59,7 +59,6 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -69,19 +68,11 @@ from onyx.db.models import UserFile
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingPayload
from onyx.hooks.points.query_processing import QueryProcessingResponse
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
@@ -433,32 +424,6 @@ def determine_search_params(
)
def _apply_query_processing_hook(
hook_result: BaseModel | HookSkipped | HookSoftFailed,
message_text: str,
) -> str:
"""Apply the Query Processing hook result to the message text.
Returns the (possibly rewritten) message text, or raises OnyxError with
QUERY_REJECTED if the hook signals rejection (query is null or empty).
HookSkipped and HookSoftFailed are pass-throughs — the original text is
returned unchanged.
"""
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
return message_text
if not isinstance(hook_result, QueryProcessingResponse):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"Expected QueryProcessingResponse from hook, got {type(hook_result).__name__}",
)
if not hook_result.query:
raise OnyxError(
OnyxErrorCode.QUERY_REJECTED,
hook_result.rejection_message or "Your query was rejected.",
)
return hook_result.query
def handle_stream_message_objects(
new_msg_req: SendMessageRequest,
user: User,
@@ -519,7 +484,6 @@ def handle_stream_message_objects(
persona = chat_session.persona
message_text = new_msg_req.message
user_identity = LLMUserIdentity(
user_id=llm_user_identifier, session_id=str(chat_session.id)
)
@@ -611,22 +575,6 @@ def handle_stream_message_objects(
if parent_message.message_type == MessageType.USER:
user_message = parent_message
else:
# New message — run the Query Processing hook before saving to DB.
# Skipped on regeneration: the message already exists and was accepted previously.
hook_result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=QueryProcessingPayload(
query=message_text,
# Pass None for anonymous users or authenticated users without an email
# (e.g. some SSO flows). QueryProcessingPayload.user_email is str | None,
# so None is accepted and serialised as null in both cases.
user_email=None if user.is_anonymous else user.email,
chat_session_id=str(chat_session.id),
).model_dump(),
)
message_text = _apply_query_processing_hook(hook_result, message_text)
user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=parent_message,
@@ -966,17 +914,6 @@ def handle_stream_message_objects(
state_container=state_container,
)
except OnyxError as e:
if e.error_code is not OnyxErrorCode.QUERY_REJECTED:
log_onyx_error(e)
yield StreamingError(
error=e.detail,
error_code=e.error_code.code,
is_retryable=e.status_code >= 500,
)
db_session.rollback()
return
except ValueError as e:
logger.exception("Failed to process chat message.")

View File

@@ -44,7 +44,6 @@ class OnyxErrorCode(Enum):
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
INVALID_INPUT = ("INVALID_INPUT", 400)
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
QUERY_REJECTED = ("QUERY_REJECTED", 400)
# ------------------------------------------------------------------
# Not Found (404)

View File

@@ -23,45 +23,55 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
return f"plaintext_{user_file_id}"
def plaintext_file_name_for_id(file_id: str) -> str:
"""Generate a consistent file name for storing plaintext content of a file."""
return f"plaintext_{file_id}"
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
def store_plaintext(file_id: str, plaintext_content: str) -> bool:
"""
Store plaintext content for a user file in the file store.
Store plaintext content for a file in the file store.
Args:
user_file_id: The ID of the user file
file_id: The ID of the file (user_file or artifact_file)
plaintext_content: The plaintext content to store
Returns:
bool: True if storage was successful, False otherwise
"""
# Skip empty content
if not plaintext_content:
return False
# Get plaintext file name
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
plaintext_file_name = plaintext_file_name_for_id(file_id)
try:
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))
file_store.save_file(
content=file_content,
display_name=f"Plaintext for user file {user_file_id}",
display_name=f"Plaintext for {file_id}",
file_origin=FileOrigin.PLAINTEXT_CACHE,
file_type="text/plain",
file_id=plaintext_file_name,
)
return True
except Exception as e:
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
logger.warning(f"Failed to store plaintext for {file_id}: {e}")
return False
# --- Convenience wrappers for callers that use user-file UUIDs ---
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
return plaintext_file_name_for_id(str(user_file_id))
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
"""Store plaintext content for a user file (delegates to :func:`store_plaintext`)."""
return store_plaintext(str(user_file_id), plaintext_content)
def load_chat_file_by_id(file_id: str) -> InMemoryChatFile:
"""Load a file directly from the file store using its file_record ID.

View File

@@ -14,7 +14,7 @@ Usage (Celery tasks and FastAPI handlers):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is a validated Pydantic model instance (spec.response_model)
# result is the response payload dict from the customer's endpoint
...
is_reachable update policy
@@ -56,7 +56,6 @@ from typing import Any
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -68,7 +67,6 @@ from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.registry import get_hook_point_spec
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
@@ -270,21 +268,22 @@ def _persist_result(
# ---------------------------------------------------------------------------
def _execute_hook_inner(
hook: Hook,
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
) -> BaseModel | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously."""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
hook_point = hook.hook_point # extract before HTTP call per design intent
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
@@ -301,37 +300,13 @@ def _execute_hook_inner(
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
with httpx.Client(timeout=timeout) as client:
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against the spec's response_model.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: BaseModel | None = None
if outcome.is_success and outcome.response_payload is not None:
spec = get_hook_point_spec(hook_point)
try:
validated_model = spec.response_model.model_validate(
outcome.response_payload
)
except ValidationError as e:
msg = f"Hook response failed validation against {spec.response_model.__name__}: {e}"
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
@@ -348,41 +323,8 @@ def _execute_hook_inner(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if validated_model is None:
if outcome.response_payload is None:
raise ValueError(
f"validated_model is None for successful hook call (hook_id={hook_id})"
f"response_payload is None for successful hook call (hook_id={hook_id})"
)
return validated_model
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
) -> BaseModel | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
try:
return _execute_hook_inner(hook, payload)
except OnyxError:
# OnyxError(HOOK_EXECUTION_FAILED) is only raised under HARD strategy in
# _execute_hook_inner, so re-raise unconditionally — never silently swallow
# an OnyxError into HookSoftFailed, even under SOFT.
raise
except Exception:
if hook.fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook.id}"
)
return HookSoftFailed()
raise
return outcome.response_payload

View File

@@ -51,12 +51,13 @@ class HookPointSpec:
output_schema: ClassVar[dict[str, Any]]
def __init_subclass__(cls, **kwargs: object) -> None:
"""Enforce that every subclass declares all required class attributes.
"""Enforce that every concrete subclass declares all required class attributes.
Called automatically by Python whenever a class inherits from HookPointSpec.
Raises TypeError at import time if any required attribute is missing or if
payload_model / response_model are not Pydantic BaseModel subclasses.
input_schema and output_schema are derived automatically from the models.
Abstract subclasses (those still carrying unimplemented abstract methods) are
skipped — they are intermediate base classes and may not yet define everything.
Only fully concrete subclasses are validated, ensuring a clear TypeError at
import time rather than a confusing AttributeError at runtime.
"""
super().__init_subclass__(**kwargs)
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]

View File

@@ -15,7 +15,7 @@ class QueryProcessingPayload(BaseModel):
description="Email of the user submitting the query, or null if unauthenticated."
)
chat_session_id: str = Field(
description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
)

View File

@@ -14,6 +14,7 @@ from __future__ import annotations
import os
import subprocess
import sys
import time
import uuid
from collections.abc import Generator
@@ -28,6 +29,9 @@ _BACKEND_DIR = os.path.normpath(
os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
)
_DROP_SCHEMA_MAX_RETRIES = 3
_DROP_SCHEMA_RETRY_DELAY_SEC = 2
# ---------------------------------------------------------------------------
# Helpers
@@ -50,6 +54,39 @@ def _run_script(
)
def _force_drop_schema(engine: Engine, schema: str) -> None:
"""Terminate backends using *schema* then drop it, retrying on deadlock.
Background Celery workers may discover test schemas (they match the
``tenant_`` prefix) and hold locks on tables inside them. A bare
``DROP SCHEMA … CASCADE`` can deadlock with those workers, so we
first kill their connections and retry if we still hit a deadlock.
"""
for attempt in range(_DROP_SCHEMA_MAX_RETRIES):
try:
with engine.connect() as conn:
conn.execute(
text(
"""
SELECT pg_terminate_backend(l.pid)
FROM pg_locks l
JOIN pg_class c ON c.oid = l.relation
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = :schema
AND l.pid != pg_backend_pid()
"""
),
{"schema": schema},
)
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
return
except Exception:
if attempt == _DROP_SCHEMA_MAX_RETRIES - 1:
raise
time.sleep(_DROP_SCHEMA_RETRY_DELAY_SEC)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@@ -104,9 +141,7 @@ def tenant_schema_at_head(
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
@pytest.fixture
@@ -123,9 +158,7 @@ def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
@pytest.fixture
@@ -150,9 +183,7 @@ def tenant_schema_bad_rev(engine: Engine) -> Generator[str, None, None]:
yield schema
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
_force_drop_schema(engine, schema)
# ---------------------------------------------------------------------------

View File

@@ -1,3 +1,5 @@
import csv
import io
import os
from datetime import datetime
from datetime import timedelta
@@ -139,12 +141,12 @@ def test_chat_history_csv_export(
assert headers["Content-Type"] == "text/csv; charset=utf-8"
assert "Content-Disposition" in headers
# Verify CSV content
csv_lines = csv_content.strip().split("\n")
assert len(csv_lines) == 3 # Header + 2 QA pairs
assert "chat_session_id" in csv_content
assert "user_message" in csv_content
assert "ai_response" in csv_content
# Use csv.reader to properly handle newlines inside quoted fields
csv_rows = list(csv.reader(io.StringIO(csv_content)))
assert len(csv_rows) == 3 # Header + 2 QA pairs
assert csv_rows[0][0] == "chat_session_id"
assert "user_message" in csv_rows[0]
assert "ai_response" in csv_rows[0]
assert "What was the Q1 revenue?" in csv_content
assert "What about Q2 revenue?" in csv_content
@@ -156,5 +158,5 @@ def test_chat_history_csv_export(
end_time=past_end,
user_performing_action=admin_user,
)
csv_lines = csv_content.strip().split("\n")
assert len(csv_lines) == 1 # Only header, no data rows
csv_rows = list(csv.reader(io.StringIO(csv_content)))
assert len(csv_rows) == 1 # Only header, no data rows

View File

@@ -1,12 +1,4 @@
import pytest
from onyx.chat.process_message import _apply_query_processing_hook
from onyx.chat.process_message import remove_answer_citations
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
def test_remove_answer_citations_strips_http_markdown_citation() -> None:
@@ -40,83 +32,3 @@ def test_remove_answer_citations_preserves_non_citation_markdown_links() -> None
remove_answer_citations(answer)
== "See [reference](https://example.com/Function_(mathematics)) for context."
)
# ---------------------------------------------------------------------------
# Query Processing hook response handling (_apply_query_processing_hook)
# ---------------------------------------------------------------------------
def test_wrong_model_type_raises_internal_error() -> None:
"""If the executor ever returns an unexpected BaseModel type, raise INTERNAL_ERROR
rather than an AssertionError or AttributeError."""
from pydantic import BaseModel as PydanticBaseModel
class _OtherModel(PydanticBaseModel):
pass
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(_OtherModel(), "original query")
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
def test_hook_skipped_leaves_message_text_unchanged() -> None:
result = _apply_query_processing_hook(HookSkipped(), "original query")
assert result == "original query"
def test_hook_soft_failed_leaves_message_text_unchanged() -> None:
result = _apply_query_processing_hook(HookSoftFailed(), "original query")
assert result == "original query"
def test_null_query_raises_query_rejected() -> None:
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=None), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_empty_string_query_raises_query_rejected() -> None:
"""Empty string is falsy — must be treated as rejection, same as None."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=""), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_absent_query_field_raises_query_rejected() -> None:
"""query defaults to None when not provided."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(QueryProcessingResponse(), "original query")
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_rejection_message_surfaced_in_error_when_provided() -> None:
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(
query=None, rejection_message="Queries about X are not allowed."
),
"original query",
)
assert "Queries about X are not allowed." in str(exc_info.value)
def test_fallback_rejection_message_when_none() -> None:
"""No rejection_message → generic fallback used in OnyxError detail."""
with pytest.raises(OnyxError) as exc_info:
_apply_query_processing_hook(
QueryProcessingResponse(query=None, rejection_message=None),
"original query",
)
assert "Your query was rejected." in str(exc_info.value)
def test_nonempty_query_rewrites_message_text() -> None:
result = _apply_query_processing_hook(
QueryProcessingResponse(query="rewritten query"), "original query"
)
assert result == "rewritten query"

View File

@@ -7,7 +7,6 @@ from unittest.mock import patch
import httpx
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
@@ -16,15 +15,13 @@ from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
# A valid QueryProcessingResponse payload — used by success-path tests.
_RESPONSE_PAYLOAD: dict[str, Any] = {"query": "better test"}
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
def _make_hook(
@@ -36,7 +33,6 @@ def _make_hook(
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
hook_id: int = 1,
is_reachable: bool | None = None,
hook_point: HookPoint = HookPoint.QUERY_PROCESSING,
) -> MagicMock:
hook = MagicMock()
hook.is_active = is_active
@@ -46,7 +42,6 @@ def _make_hook(
hook.id = hook_id
hook.fail_strategy = fail_strategy
hook.is_reachable = is_reachable
hook.hook_point = hook_point
return hook
@@ -157,9 +152,7 @@ def test_early_exit_returns_skipped_with_no_db_writes(
# ---------------------------------------------------------------------------
def test_success_returns_validated_model_and_sets_reachable(
db_session: MagicMock,
) -> None:
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
hook = _make_hook()
with (
@@ -180,8 +173,7 @@ def test_success_returns_validated_model_and_sets_reachable(
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
assert result == _RESPONSE_PAYLOAD
_, update_kwargs = mock_update.call_args
assert update_kwargs["is_reachable"] is True
mock_log.assert_not_called()
@@ -210,8 +202,7 @@ def test_success_skips_reachable_write_when_already_true(db_session: MagicMock)
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
assert result == _RESPONSE_PAYLOAD
mock_update.assert_not_called()
@@ -466,16 +457,16 @@ def test_authorization_header(
@pytest.mark.parametrize(
"http_exception,expect_onyx_error",
"http_exception,expected_result",
[
pytest.param(None, False, id="success_path"),
pytest.param(httpx.ConnectError("refused"), True, id="hard_fail_path"),
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
],
)
def test_persist_session_failure_is_swallowed(
db_session: MagicMock,
http_exception: Exception | None,
expect_onyx_error: bool,
expected_result: Any,
) -> None:
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
@@ -498,7 +489,7 @@ def test_persist_session_failure_is_swallowed(
side_effect=http_exception,
)
if expect_onyx_error:
if expected_result is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
@@ -512,162 +503,7 @@ def test_persist_session_failure_is_swallowed(
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
# ---------------------------------------------------------------------------
# Response model validation
# ---------------------------------------------------------------------------
class _StrictResponse(BaseModel):
"""Strict model used to reliably trigger a ValidationError in tests."""
required_field: str # no default → missing key raises ValidationError
def _make_strict_spec() -> MagicMock:
spec = MagicMock()
spec.response_model = _StrictResponse
return spec
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(
HookFailStrategy.SOFT, HookSoftFailed, id="validation_failure_soft"
),
pytest.param(HookFailStrategy.HARD, OnyxError, id="validation_failure_hard"),
],
)
def test_response_validation_failure_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""A response that fails response_model validation is treated like any other
hook failure: logged, is_reachable left unchanged, fail_strategy respected."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch(
"onyx.hooks.executor.get_hook_point_spec",
return_value=_make_strict_spec(),
),
patch("httpx.Client") as mock_client_cls,
):
# Response payload is missing required_field → ValidationError
_setup_client(mock_client_cls, response=_make_response(json_return={}))
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
# is_reachable must not be updated — server responded correctly
mock_update.assert_not_called()
# failure must be logged
mock_log.assert_called_once()
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "validation" in (log_kwargs["error_message"] or "").lower()
# ---------------------------------------------------------------------------
# Outer soft-fail guard in execute_hook
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(HookFailStrategy.SOFT, HookSoftFailed, id="unexpected_exc_soft"),
pytest.param(HookFailStrategy.HARD, ValueError, id="unexpected_exc_hard"),
],
)
def test_unexpected_exception_in_inner_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""An unexpected exception raised by _execute_hook_inner (not an OnyxError from
HARD fail — e.g. a bug or an assertion error) must be swallowed and return
HookSoftFailed for SOFT strategy, or re-raised for HARD strategy."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor._execute_hook_inner",
side_effect=ValueError("unexpected bug"),
),
):
if expected_type is HookSoftFailed:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
else:
with pytest.raises(ValueError, match="unexpected bug"):
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
def test_onyx_error_from_inner_is_never_swallowed_under_soft(
db_session: MagicMock,
) -> None:
"""OnyxError raised by _execute_hook_inner must be re-raised even under SOFT
strategy — the except OnyxError: raise guard must not let it get swallowed."""
hook = _make_hook(fail_strategy=HookFailStrategy.SOFT)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor._execute_hook_inner",
side_effect=OnyxError(OnyxErrorCode.HOOK_EXECUTION_FAILED, "hard fail"),
),
):
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
assert result == expected_result
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:

View File

@@ -70,8 +70,8 @@ function Prompt-OrDefault {
function Confirm-Action {
param([string]$Description)
$reply = Prompt-OrDefault "Install $Description? (Y/n) [default: Y]" "Y"
if ($reply -match '^[Nn]') {
$reply = (Prompt-OrDefault "Install $Description? (Y/n) [default: Y]" "Y").Trim().ToLower()
if ($reply -match '^n') {
Print-Warning "Skipping: $Description"
return $false
}
@@ -364,7 +364,7 @@ function Invoke-OnyxDeleteData {
Write-Host "`n=== WARNING: This will permanently delete all Onyx data ===`n" -ForegroundColor Red
Print-Warning "This action will remove all Onyx containers, volumes, files, and user data."
if (Test-Interactive) {
$confirm = Read-Host "Type 'DELETE' to confirm"
$confirm = Prompt-OrDefault "Type 'DELETE' to confirm" ""
if ($confirm -ne "DELETE") { Print-Info "Operation cancelled."; return }
} else {
Print-OnyxError "Cannot confirm destructive operation in non-interactive mode."
@@ -720,6 +720,7 @@ function Invoke-WslInstall {
# Ensure WSL2 is available
Invoke-NativeQuiet { wsl --status }
if ($LASTEXITCODE -ne 0) {
if (-not (Confirm-Action "WSL2 (Windows Subsystem for Linux)")) { exit 1 }
Print-Info "Installing WSL2..."
try {
$proc = Start-Process wsl -ArgumentList "--install", "--no-distribution" -Wait -PassThru -NoNewWindow
@@ -806,7 +807,7 @@ function Main {
if (Test-Interactive) {
Write-Host "`nPlease acknowledge and press Enter to continue..." -ForegroundColor Yellow
Read-Host | Out-Null
$null = Prompt-OrDefault "" ""
} else {
Write-Host "`nRunning in non-interactive mode - proceeding automatically..." -ForegroundColor Yellow
}
@@ -902,8 +903,8 @@ function Main {
if ($resourceWarning) {
Print-Warning "Onyx recommends at least $($script:ExpectedDockerRamGB)GB RAM and $($script:ExpectedDiskGB)GB disk for standard mode."
Print-Warning "Lite mode requires less (1-4GB RAM, 8-16GB disk) but has no vector database."
$reply = Prompt-OrDefault "Do you want to continue anyway? (Y/n)" "y"
if ($reply -notmatch '^[Yy]') { Print-Info "Installation cancelled."; exit 1 }
$reply = (Prompt-OrDefault "Do you want to continue anyway? (Y/n)" "y").Trim().ToLower()
if ($reply -notmatch '^y') { Print-Info "Installation cancelled."; exit 1 }
Print-Info "Proceeding despite resource limitations..."
}
@@ -925,8 +926,8 @@ function Main {
if ($composeVersion -ne "unknown" -and (Compare-SemVer $composeVersion "2.24.0") -lt 0) {
Print-Warning "Docker Compose $composeVersion is older than 2.24.0 (required for env_file format)."
Print-Info "Update Docker Desktop or install a newer Docker Compose. Installation may fail."
$reply = Prompt-OrDefault "Continue anyway? (Y/n)" "y"
if ($reply -notmatch '^[Yy]') { exit 1 }
$reply = (Prompt-OrDefault "Continue anyway? (Y/n)" "y").Trim().ToLower()
if ($reply -notmatch '^y') { exit 1 }
}
$liteOverlayPath = Join-Path $deploymentDir $script:LiteComposeFile

View File

@@ -174,34 +174,42 @@ ensure_file() {
# --- Interactive prompt helpers ---
is_interactive() {
[[ "$NO_PROMPT" = false ]]
[[ "$NO_PROMPT" = false ]] && [[ -r /dev/tty ]] && [[ -w /dev/tty ]]
}
read_prompt_line() {
local prompt_text="$1"
if ! is_interactive; then
REPLY=""
return
fi
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
IFS= read -r REPLY < /dev/tty || REPLY=""
}
read_prompt_char() {
local prompt_text="$1"
if ! is_interactive; then
REPLY=""
return
fi
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
IFS= read -r -n 1 REPLY < /dev/tty || REPLY=""
printf "\n" > /dev/tty
}
prompt_or_default() {
local prompt_text="$1"
local default_value="$2"
if is_interactive; then
read -p "$prompt_text" -r REPLY
if [[ -z "$REPLY" ]]; then
REPLY="$default_value"
fi
else
REPLY="$default_value"
fi
read_prompt_line "$prompt_text"
[[ -z "$REPLY" ]] && REPLY="$default_value"
}
prompt_yn_or_default() {
local prompt_text="$1"
local default_value="$2"
if is_interactive; then
read -p "$prompt_text" -n 1 -r
echo ""
if [[ -z "$REPLY" ]]; then
REPLY="$default_value"
fi
else
REPLY="$default_value"
fi
read_prompt_char "$prompt_text"
[[ -z "$REPLY" ]] && REPLY="$default_value"
}
confirm_action() {
@@ -302,8 +310,8 @@ if [ "$DELETE_DATA_MODE" = true ]; then
echo " • All user data and documents"
echo ""
if is_interactive; then
read -p "Are you sure you want to continue? Type 'DELETE' to confirm: " -r
echo ""
prompt_or_default "Are you sure you want to continue? Type 'DELETE' to confirm: " ""
echo "" > /dev/tty
if [ "$REPLY" != "DELETE" ]; then
print_info "Operation cancelled."
exit 0
@@ -497,7 +505,7 @@ echo ""
if is_interactive; then
echo -e "${YELLOW}${BOLD}Please acknowledge and press Enter to continue...${NC}"
read -r
read_prompt_line ""
echo ""
else
echo -e "${YELLOW}${BOLD}Running in non-interactive mode - proceeding automatically...${NC}"

View File

@@ -45,7 +45,7 @@ function MemoryTagWithTooltip({
<MemoriesModal
initialTargetMemoryId={memoryId}
initialTargetIndex={memoryIndex}
highlightFirstOnOpen
highlightOnOpen
/>
</memoriesModal.Provider>
{memoriesModal.isOpen ? (
@@ -56,8 +56,14 @@ function MemoryTagWithTooltip({
side="bottom"
className="bg-background-neutral-00 text-text-01 shadow-md max-w-[17.5rem] p-1"
tooltip={
<Section flexDirection="column" gap={0.25} height="auto">
<div className="p-1 w-full">
<Section
flexDirection="column"
alignItems="start"
padding={0.25}
gap={0.25}
height="auto"
>
<div className="p-1">
<Text as="p" secondaryBody text03>
{memoryText}
</Text>
@@ -66,6 +72,7 @@ function MemoryTagWithTooltip({
icon={SvgAddLines}
title={operationLabel}
sizePreset="secondary"
paddingVariant="sm"
variant="body"
prominence="muted"
rightChildren={

View File

@@ -103,7 +103,7 @@ export const MemoryToolRenderer: MessageRenderer<MemoryToolPacket, {}> = ({
<MemoriesModal
initialTargetMemoryId={memoryId}
initialTargetIndex={index}
highlightFirstOnOpen
highlightOnOpen
/>
</memoriesModal.Provider>
{memoryText ? (

View File

@@ -125,7 +125,7 @@ export const MAX_FILES_TO_SHOW = 3;
export const MOBILE_SIDEBAR_BREAKPOINT_PX = 640;
export const DESKTOP_SMALL_BREAKPOINT_PX = 912;
export const DESKTOP_MEDIUM_BREAKPOINT_PX = 1232;
export const DEFAULT_AGENT_AVATAR_SIZE_PX = 18;
export const DEFAULT_AVATAR_SIZE_PX = 18;
export const HORIZON_DISTANCE_PX = 800;
export const LOGO_FOLDED_SIZE_PX = 24;
export const LOGO_UNFOLDED_SIZE_PX = 88;

55
web/src/lib/user.test.ts Normal file
View File

@@ -0,0 +1,55 @@
import { getUserInitials } from "@/lib/user";
describe("getUserInitials", () => {
it("returns first letters of first two name parts", () => {
expect(getUserInitials("Alice Smith", "alice@example.com")).toBe("AS");
});
it("returns first two chars of a single-word name", () => {
expect(getUserInitials("Alice", "alice@example.com")).toBe("AL");
});
it("handles three-word names (uses first two)", () => {
expect(getUserInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
});
it("falls back to email local part with dot separator", () => {
expect(getUserInitials(null, "alice.smith@example.com")).toBe("AS");
});
it("falls back to email local part with underscore separator", () => {
expect(getUserInitials(null, "alice_smith@example.com")).toBe("AS");
});
it("falls back to email local part with hyphen separator", () => {
expect(getUserInitials(null, "alice-smith@example.com")).toBe("AS");
});
it("uses first two chars of email local if no separator", () => {
expect(getUserInitials(null, "alice@example.com")).toBe("AL");
});
it("returns null for empty email local part", () => {
expect(getUserInitials(null, "@example.com")).toBeNull();
});
it("uppercases the result", () => {
expect(getUserInitials("john doe", "jd@test.com")).toBe("JD");
});
it("trims whitespace from name", () => {
expect(getUserInitials(" Alice Smith ", "a@test.com")).toBe("AS");
});
it("returns null for numeric name parts", () => {
expect(getUserInitials("Alice 1st", "x@test.com")).toBeNull();
});
it("returns null for numeric email", () => {
expect(getUserInitials(null, "42@domain.com")).toBeNull();
});
it("falls back to email when name has non-alpha chars", () => {
expect(getUserInitials("A1", "alice@example.com")).toBe("AL");
});
});

View File

@@ -128,3 +128,54 @@ export function getUserDisplayName(user: User | null): string {
// If nothing works, then fall back to anonymous user name
return "Anonymous";
}
/**
* Derive display initials from a user's name or email.
*
* - If a name is provided, uses the first letter of the first two words.
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
* - Returns `null` when no valid alpha initials can be derived.
*/
export function getUserInitials(
name: string | null,
email: string
): string | null {
if (name) {
const words = name.trim().split(/\s+/);
if (words.length >= 2) {
const first = words[0]?.[0];
const second = words[1]?.[0];
if (first && second) {
const result = (first + second).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
return null;
}
if (name.trim().length >= 1) {
const result = name.trim().slice(0, 2).toUpperCase();
if (/^[A-Z]{1,2}$/.test(result)) return result;
}
}
const local = email.split("@")[0];
if (!local || local.length === 0) return null;
const parts = local.split(/[._-]/);
if (parts.length >= 2) {
const first = parts[0]?.[0];
const second = parts[1]?.[0];
if (first && second) {
const result = (first + second).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
return null;
}
if (local.length >= 2) {
const result = local.slice(0, 2).toUpperCase();
if (/^[A-Z]{2}$/.test(result)) return result;
}
if (local.length === 1) {
const result = local.toUpperCase();
if (/^[A-Z]$/.test(result)) return result;
}
return null;
}

View File

@@ -4,10 +4,7 @@ import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
import { buildImgUrl } from "@/app/app/components/files/images/utils";
import { OnyxIcon } from "@/components/icons/icons";
import { useSettingsContext } from "@/providers/SettingsProvider";
import {
DEFAULT_AGENT_AVATAR_SIZE_PX,
DEFAULT_AGENT_ID,
} from "@/lib/constants";
import { DEFAULT_AVATAR_SIZE_PX, DEFAULT_AGENT_ID } from "@/lib/constants";
import CustomAgentAvatar from "@/refresh-components/avatars/CustomAgentAvatar";
import Image from "next/image";
@@ -18,7 +15,7 @@ export interface AgentAvatarProps {
export default function AgentAvatar({
agent,
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
size = DEFAULT_AVATAR_SIZE_PX,
...props
}: AgentAvatarProps) {
const settings = useSettingsContext();

View File

@@ -4,7 +4,7 @@ import { cn } from "@/lib/utils";
import type { IconProps } from "@opal/types";
import Text from "@/refresh-components/texts/Text";
import Image from "next/image";
import { DEFAULT_AGENT_AVATAR_SIZE_PX } from "@/lib/constants";
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
import {
SvgActivitySmall,
SvgAudioEqSmall,
@@ -96,7 +96,7 @@ export default function CustomAgentAvatar({
src,
iconName,
size = DEFAULT_AGENT_AVATAR_SIZE_PX,
size = DEFAULT_AVATAR_SIZE_PX,
}: CustomAgentAvatarProps) {
if (src) {
return (

View File

@@ -0,0 +1,52 @@
import { SvgUser } from "@opal/icons";
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
import { getUserInitials } from "@/lib/user";
import Text from "@/refresh-components/texts/Text";
import type { User } from "@/lib/types";
export interface UserAvatarProps {
user: User;
size?: number;
}
export default function UserAvatar({
user,
size = DEFAULT_AVATAR_SIZE_PX,
}: UserAvatarProps) {
const initials = getUserInitials(
user.personalization?.name ?? null,
user.email
);
if (!initials) {
return (
<div
role="img"
aria-label={`${user.email} avatar`}
className="flex items-center justify-center rounded-full bg-background-tint-01"
style={{ width: size, height: size }}
>
<SvgUser size={size * 0.55} className="stroke-text-03" aria-hidden />
</div>
);
}
return (
<div
role="img"
aria-label={`${user.email} avatar`}
className="flex items-center justify-center rounded-full bg-background-neutral-inverted-00"
style={{ width: size, height: size }}
>
<Text
inverted
secondaryAction
text05
className="select-none"
style={{ fontSize: size * 0.4 }}
>
{initials}
</Text>
</div>
);
}

View File

@@ -54,8 +54,10 @@ function MemoryItem({
const wrapperRef = useRef<HTMLDivElement>(null);
useEffect(() => {
if (shouldFocus) {
textareaRef.current?.focus();
if (shouldFocus && textareaRef.current) {
const el = textareaRef.current;
el.focus();
el.selectionStart = el.selectionEnd = el.value.length;
onFocused?.();
}
}, [shouldFocus, onFocused]);
@@ -63,8 +65,10 @@ function MemoryItem({
useEffect(() => {
if (!shouldHighlight) return;
wrapperRef.current?.scrollIntoView({ block: "center", behavior: "smooth" });
textareaRef.current?.focus();
wrapperRef.current?.scrollIntoView({
block: "start",
behavior: "smooth",
});
setIsHighlighting(true);
const timer = setTimeout(() => {
@@ -79,10 +83,10 @@ function MemoryItem({
<div
ref={wrapperRef}
className={cn(
"rounded-08 hover:bg-background-tint-00 w-full p-0.5",
"rounded-08 w-full p-0.5 border border-transparent",
"transition-colors ",
isHighlighting &&
"bg-action-link-01 border border-action-link-05 duration-700"
"bg-action-link-01 hover:bg-action-link-01 border-action-link-05 duration-700"
)}
>
<Section gap={0.25} alignItems="start">
@@ -100,7 +104,7 @@ function MemoryItem({
rows={3}
maxLength={MAX_MEMORY_LENGTH}
resizable={false}
className={cn(!isFocused && "bg-transparent")}
className="bg-background-tint-01 hover:bg-background-tint-00 focus-within:bg-background-tint-00"
/>
<Disabled disabled={!memory.content.trim() && memory.isNew}>
<Button
@@ -122,13 +126,29 @@ function MemoryItem({
);
}
function resolveTargetMemoryId(
targetMemoryId: number | null | undefined,
targetIndex: number | null | undefined,
memories: MemoryItem[]
): number | null {
if (targetMemoryId != null) return targetMemoryId;
if (targetIndex != null && memories.length > 0) {
// Backend index is ASC (oldest-first), frontend displays DESC (newest-first)
const descIdx = memories.length - 1 - targetIndex;
return memories[descIdx]?.id ?? null;
}
return null;
}
interface MemoriesModalProps {
memories?: MemoryItem[];
onSaveMemories?: (memories: MemoryItem[]) => Promise<boolean>;
onClose?: () => void;
initialTargetMemoryId?: number | null;
initialTargetIndex?: number | null;
highlightFirstOnOpen?: boolean;
highlightOnOpen?: boolean;
}
export default function MemoriesModal({
@@ -137,7 +157,7 @@ export default function MemoriesModal({
onClose,
initialTargetMemoryId,
initialTargetIndex,
highlightFirstOnOpen = false,
highlightOnOpen = false,
}: MemoriesModalProps) {
const close = useModalClose(onClose);
const [focusMemoryId, setFocusMemoryId] = useState<number | null>(null);
@@ -182,24 +202,16 @@ export default function MemoriesModal({
);
useEffect(() => {
if (initialTargetMemoryId != null) {
// Direct DB id available — use it
setHighlightMemoryId(initialTargetMemoryId);
} else if (initialTargetIndex != null && effectiveMemories.length > 0) {
// Backend index is ASC (oldest-first), but the frontend displays DESC
// (newest-first). Convert: descIdx = totalCount - 1 - ascIdx
const descIdx = effectiveMemories.length - 1 - initialTargetIndex;
const target = effectiveMemories[descIdx];
if (target) {
setHighlightMemoryId(target.id);
}
} else if (
highlightFirstOnOpen &&
effectiveMemories.length > 0 &&
effectiveMemories[0]
) {
// Fallback: highlight the first displayed item (newest)
setHighlightMemoryId(effectiveMemories[0].id);
const targetId = resolveTargetMemoryId(
initialTargetMemoryId,
initialTargetIndex,
effectiveMemories
);
if (targetId == null) return;
setFocusMemoryId(targetId);
if (highlightOnOpen) {
setHighlightMemoryId(targetId);
}
}, [initialTargetMemoryId, initialTargetIndex]);

View File

@@ -26,7 +26,7 @@ import type {
StatusFilter,
StatusCountMap,
} from "./interfaces";
import { getInitials } from "./utils";
import { getUserInitials } from "@/lib/user";
// ---------------------------------------------------------------------------
// Column renderers
@@ -76,7 +76,8 @@ function buildColumns(onMutate: () => void) {
return [
tc.qualifier({
content: "avatar-user",
getInitials: (row) => getInitials(row.personal_name, row.email),
getInitials: (row) =>
getUserInitials(row.personal_name, row.email) ?? "?",
selectable: false,
}),
tc.column("email", {

View File

@@ -1,43 +0,0 @@
import { getInitials } from "./utils";
describe("getInitials", () => {
it("returns first letters of first two name parts", () => {
expect(getInitials("Alice Smith", "alice@example.com")).toBe("AS");
});
it("returns first two chars of a single-word name", () => {
expect(getInitials("Alice", "alice@example.com")).toBe("AL");
});
it("handles three-word names (uses first two)", () => {
expect(getInitials("Alice B. Smith", "alice@example.com")).toBe("AB");
});
it("falls back to email local part with dot separator", () => {
expect(getInitials(null, "alice.smith@example.com")).toBe("AS");
});
it("falls back to email local part with underscore separator", () => {
expect(getInitials(null, "alice_smith@example.com")).toBe("AS");
});
it("falls back to email local part with hyphen separator", () => {
expect(getInitials(null, "alice-smith@example.com")).toBe("AS");
});
it("uses first two chars of email local if no separator", () => {
expect(getInitials(null, "alice@example.com")).toBe("AL");
});
it("returns ? for empty email local part", () => {
expect(getInitials(null, "@example.com")).toBe("?");
});
it("uppercases the result", () => {
expect(getInitials("john doe", "jd@test.com")).toBe("JD");
});
it("trims whitespace from name", () => {
expect(getInitials(" Alice Smith ", "a@test.com")).toBe("AS");
});
});

View File

@@ -1,23 +0,0 @@
/**
* Derive display initials from a user's name or email.
*
* - If a name is provided, uses the first letter of the first two words.
* - Falls back to the email local part, splitting on `.`, `_`, or `-`.
* - Returns at most 2 uppercase characters.
*/
export function getInitials(name: string | null, email: string): string {
if (name) {
const parts = name.trim().split(/\s+/);
if (parts.length >= 2) {
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
}
return name.slice(0, 2).toUpperCase();
}
const local = email.split("@")[0];
if (!local) return "?";
const parts = local.split(/[._-]/);
if (parts.length >= 2) {
return ((parts[0]?.[0] ?? "") + (parts[1]?.[0] ?? "")).toUpperCase();
}
return local.slice(0, 2).toUpperCase();
}